Unverified Commit ddf80f5e authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation

fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents 4d0483f5 c048ca80
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("logout-pending-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("logout-subject-123").
SetBrowserSessionKey("logout-browser-session-key").
SetResolvedEmail("logout@example.com").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
cookies := recorder.Result().Cookies()
for _, name := range []string{
oauthPendingSessionCookieName,
oauthPendingBrowserCookieName,
oauthBindAccessTokenCookieName,
linuxDoOAuthStateCookieName,
oidcOAuthStateCookieName,
wechatOAuthStateCookieName,
wechatPaymentOAuthStateName,
} {
cookie := findCookie(cookies, name)
require.NotNil(t, cookie, name)
require.Equal(t, -1, cookie.MaxAge, name)
require.True(t, cookie.HttpOnly, name)
}
storedSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}
......@@ -265,16 +265,20 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
if len(payload) == 0 {
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
}
for _, key := range []string{"access_token", "refresh_token"} {
if value := pendingSessionStringValue(payload, key); value != "" {
return true
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
return false
}
if pendingSessionWantsInvitation(payload) {
return false
}
return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
......@@ -294,6 +298,78 @@ func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSes
return nil
}
func buildLegacyCompleteRegistrationPendingResponse(
session *dbent.PendingAuthSession,
forceEmailOnSignup bool,
emailVerificationRequired bool,
) map[string]any {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
}))
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
if _, exists := completionResponse["email"]; !exists {
completionResponse["email"] = email
}
if _, exists := completionResponse["resolved_email"]; !exists {
completionResponse["resolved_email"] = email
}
}
if _, exists := completionResponse["choice_reason"]; !exists {
switch {
case forceEmailOnSignup:
completionResponse["choice_reason"] = "force_email_on_signup"
case emailVerificationRequired:
completionResponse["choice_reason"] = "email_verification_required"
default:
completionResponse["choice_reason"] = "third_party_signup"
}
}
return completionResponse
}
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
c *gin.Context,
session *dbent.PendingAuthSession,
) (*dbent.PendingAuthSession, bool, error) {
if session == nil {
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
if step := pendingSessionStringValue(payload, "step"); step != "" {
return session, true, nil
}
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
if !emailVerificationRequired && !forceEmailOnSignup {
return session, false, nil
}
client := h.entClient()
if client == nil {
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
updatedSession, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
strings.TrimSpace(session.ResolvedEmail),
nil,
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
)
if err != nil {
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return updatedSession, true, nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
......@@ -376,15 +452,7 @@ func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity servic
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
userEntity, err := client.User.Get(ctx, record.UserID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
return userEntity, nil
return findActiveUserByID(ctx, client, record.UserID)
}
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
......@@ -439,7 +507,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email := strings.TrimSpace(strings.ToLower(req.Email))
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -624,6 +692,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return matches[0], nil
}
func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
if client == nil || session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity == nil || identity.UserID <= 0 {
return nil
}
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
if err != nil {
return err
}
if activeOwner != nil {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
......@@ -910,6 +1010,9 @@ func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64)
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
return nil, service.ErrUserNotActive
}
return userEntity, nil
}
......@@ -1123,6 +1226,38 @@ func consumePendingOAuthBrowserSessionTx(
return nil
}
func applyPendingOAuthAdoptionAndConsumeSession(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
userID int64,
) error {
if client == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if session == nil || userID <= 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
return err
}
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
......@@ -1212,13 +1347,7 @@ func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
if session == nil || len(payload) == 0 {
return false, nil
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false, nil
}
if !pendingOAuthCompletionIncludesTokenPayload(payload) {
return false, nil
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
return false, nil
}
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
......@@ -1262,6 +1391,59 @@ func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.Au
return svc, session, clearCookies, nil
}
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
if c == nil || c.Request == nil {
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
return
}
svc, err := h.pendingIdentityService()
if err != nil {
return
}
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
}
func clearOAuthLogoutCookies(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
clearOAuthBindAccessTokenCookie(c, secureCookie)
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
payload := gin.H{
......@@ -1280,6 +1462,9 @@ func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gi
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
normalized := clonePendingMap(payload)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
......@@ -1315,16 +1500,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c *gin.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
targetUser *dbent.User,
email string,
) (*dbent.PendingAuthSession, error) {
completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
var targetUserID *int64
if targetUser != nil && targetUser.ID > 0 {
targetUserID = &targetUser.ID
}
session, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
email,
nil,
targetUserID,
completionResponse,
)
if err != nil {
......@@ -1438,6 +1628,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
......@@ -1464,7 +1658,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
}
if existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -1487,7 +1681,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
)
if err != nil {
if errors.Is(err, service.ErrEmailExists) {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email)
existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
if lookupErr != nil {
response.ErrorFrom(c, lookupErr)
return
}
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
......@@ -1649,33 +1848,35 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
var loginUser *service.User
if canIssueTokenPair {
loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if skipAdoptionPrompt {
delete(payload, "adoption_required")
}
if pendingOAuthCompletionIncludesTokenPayload(payload) {
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
if err := ensureLoginUserActive(loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
response.ErrorFrom(c, err)
return
}
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
}
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if skipAdoptionPrompt {
delete(payload, "adoption_required")
}
if pendingSessionWantsInvitation(payload) {
......@@ -1724,6 +1925,20 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
return
}
if canIssueTokenPair {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
if err != nil {
clearCookies()
response.InternalError(c, "Failed to generate token pair")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
payload["access_token"] = tokenPair.AccessToken
payload["refresh_token"] = tokenPair.RefreshToken
payload["expires_in"] = tokenPair.ExpiresIn
payload["token_type"] = "Bearer"
}
clearCookies()
response.Success(c, payload)
}
......@@ -746,8 +746,8 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
......@@ -769,13 +769,23 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
require.Equal(t, http.StatusOK, recorder.Code)
payload := decodeJSONResponseData(t, recorder)
require.Equal(t, "access-token", payload["access_token"])
require.Equal(t, "refresh-token", payload["refresh_token"])
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
require.NotEqual(t, "legacy-access-token", payload["access_token"])
require.NotEqual(t, "legacy-refresh-token", payload["refresh_token"])
require.Equal(t, "/dashboard", payload["redirect"])
require.Equal(t, "Existing Login Example", payload["suggested_display_name"])
require.Equal(t, "https://cdn.example/existing-login.png", payload["suggested_avatar_url"])
require.NotContains(t, payload, "adoption_required")
accessToken, ok := payload["access_token"].(string)
require.True(t, ok)
claims, err := handler.authService.ValidateToken(accessToken)
require.NoError(t, err)
reloadedUser, err := handler.userService.GetByID(ctx, userEntity.ID)
require.NoError(t, err)
require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
......@@ -785,6 +795,14 @@ func TestExchangePendingOAuthCompletionExistingLoginWithSuggestedProfileSkipsAdo
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
completion, ok := storedSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.NotContains(t, completion, "access_token")
require.NotContains(t, completion, "refresh_token")
require.NotContains(t, completion, "expires_in")
require.NotContains(t, completion, "token_type")
require.Equal(t, "/dashboard", completion["redirect"])
}
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
......@@ -841,6 +859,72 @@ func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayl
require.Nil(t, storedSession.ConsumedAt)
}
func TestExchangePendingOAuthCompletionRejectsDisabledTargetUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("disabled-linked@example.com").
SetUsername("disabled-linked-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("disabled-linked-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("disabled-linked-subject").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("disabled-linked-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"suggested_display_name": "Disabled Linked User",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"redirect": "/dashboard",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("disabled-linked-browser-session-key")})
ginCtx.Request = req
handler.ExchangePendingOAuthCompletion(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestNormalizePendingOAuthCompletionResponseScrubsLegacyTokenPayload(t *testing.T) {
payload := normalizePendingOAuthCompletionResponse(map[string]any{
"access_token": "legacy-access-token",
"refresh_token": "legacy-refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
"redirect": "/dashboard",
})
require.NotContains(t, payload, "access_token")
require.NotContains(t, payload, "refresh_token")
require.NotContains(t, payload, "expires_in")
require.NotContains(t, payload, "token_type")
require.Equal(t, "/dashboard", payload["redirect"])
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background()
......@@ -969,7 +1053,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
......@@ -1023,7 +1107,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
require.Nil(t, storedSession.ConsumedAt)
......@@ -1042,7 +1127,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail(" Owner@Example.com ").
SetUsername("owner-user").
SetPasswordHash("hash").
......@@ -1088,7 +1173,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
......@@ -1096,7 +1182,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background()
_, err := client.User.Create().
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
......@@ -1144,7 +1230,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID)
require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
}
......@@ -1202,6 +1289,26 @@ func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T)
require.Nil(t, storedSession.ConsumedAt)
}
func TestLogoutClearsPendingOAuthAndBindCookies(t *testing.T) {
handler, _ := newOAuthPendingFlowTestHandler(t, false)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue("pending-session-token")})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("pending-browser-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-token"})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthPendingBrowserCookieName).MaxAge)
require.Equal(t, -1, findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName).MaxAge)
}
func TestCreateOIDCOAuthAccountRollsBackCreatedUserWhenBindingFails(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, true, "fresh@example.com", "246810")
ctx := context.Background()
......@@ -1934,6 +2041,13 @@ func TestLogin2FACompletesPendingOAuthBindAndConsumesSession(t *testing.T) {
payload := decodeJSONResponseData(t, recorder)
require.NotEmpty(t, payload["access_token"])
require.NotEmpty(t, payload["refresh_token"])
accessToken, ok := payload["access_token"].(string)
require.True(t, ok)
claims, err := handler.authService.ValidateToken(accessToken)
require.NoError(t, err)
reloadedUser, err := handler.userService.GetByID(ctx, existingUser.ID)
require.NoError(t, err)
require.Equal(t, reloadedUser.TokenVersion, claims.TokenVersion)
identity, err := client.AuthIdentity.Query().
Where(
......
......@@ -2,6 +2,7 @@ package handler
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/require"
......@@ -37,3 +38,20 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
require.NoError(t, err)
return decoded
}
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
require.NotEmpty(t, location)
parsed, err := url.Parse(location)
require.NoError(t, err)
rawValues := parsed.RawQuery
if rawValues == "" {
rawValues = parsed.Fragment
}
values, err := url.ParseQuery(rawValues)
require.NoError(t, err)
require.Equal(t, errorCode, values.Get("error"))
require.Equal(t, errorMessage, values.Get("error_message"))
}
......@@ -157,6 +157,7 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
codeChallenge := ""
if cfg.UsePKCE {
verifier, genErr := oauth.GenerateCodeVerifier()
if genErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(genErr))
......@@ -164,14 +165,17 @@ func (h *AuthHandler) OIDCOAuthStart(c *gin.Context) {
}
codeChallenge = oauth.GenerateCodeChallenge(verifier)
oidcSetCookie(c, oidcOAuthVerifierCookie, encodeCookieValue(verifier), oidcOAuthCookieMaxAgeSec, secureCookie)
}
nonce := ""
if cfg.ValidateIDToken {
nonce, err = oauth.GenerateState()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_NONCE_GEN_FAILED", "failed to generate oauth nonce").WithCause(err))
return
}
oidcSetCookie(c, oidcOAuthNonceCookie, encodeCookieValue(nonce), oidcOAuthCookieMaxAgeSec, secureCookie)
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
......@@ -244,18 +248,22 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
codeVerifier, _ = readCookieDecoded(c, oidcOAuthVerifierCookie)
if codeVerifier == "" {
redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "")
return
}
}
expectedNonce := ""
if cfg.ValidateIDToken {
expectedNonce, _ = readCookieDecoded(c, oidcOAuthNonceCookie)
if expectedNonce == "" {
redirectOAuthError(c, frontendCallback, "missing_nonce", "missing oauth nonce", "")
return
}
}
redirectURI := strings.TrimSpace(cfg.RedirectURL)
if redirectURI == "" {
......@@ -284,17 +292,20 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
var idClaims *oidcIDTokenClaims
if cfg.ValidateIDToken {
if strings.TrimSpace(tokenResp.IDToken) == "" {
redirectOAuthError(c, frontendCallback, "missing_id_token", "missing id_token", "")
return
}
idClaims, err := oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
idClaims, err = oidcParseAndValidateIDToken(c.Request.Context(), cfg, tokenResp.IDToken, expectedNonce)
if err != nil {
log.Printf("[OIDC OAuth] id_token validation failed: %v", err)
redirectOAuthError(c, frontendCallback, "invalid_id_token", "failed to validate id_token", "")
return
}
}
userInfoClaims, err := oidcFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
......@@ -303,7 +314,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
subject := strings.TrimSpace(idClaims.Subject)
subject := ""
if idClaims != nil {
subject = strings.TrimSpace(idClaims.Subject)
}
if subject == "" {
subject = strings.TrimSpace(userInfoClaims.Subject)
}
......@@ -311,7 +325,10 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "missing_subject", "missing subject claim", "")
return
}
issuer := strings.TrimSpace(idClaims.Issuer)
issuer := ""
if idClaims != nil {
issuer = strings.TrimSpace(idClaims.Issuer)
}
if issuer == "" {
issuer = strings.TrimSpace(cfg.IssuerURL)
}
......@@ -321,21 +338,34 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
emailVerified := userInfoClaims.EmailVerified
if emailVerified == nil {
if emailVerified == nil && idClaims != nil {
emailVerified = idClaims.EmailVerified
}
if userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
if idClaims != nil && userInfoClaims.Subject != "" && idClaims.Subject != "" && strings.TrimSpace(userInfoClaims.Subject) != strings.TrimSpace(idClaims.Subject) {
redirectOAuthError(c, frontendCallback, "subject_mismatch", "userinfo subject does not match id_token", "")
return
}
identityKey := oidcIdentityKey(issuer, subject)
compatEmail := strings.TrimSpace(firstNonEmpty(userInfoClaims.Email, idClaims.Email))
compatEmail := strings.TrimSpace(userInfoClaims.Email)
if compatEmail == "" && idClaims != nil {
compatEmail = strings.TrimSpace(idClaims.Email)
}
email := oidcSyntheticEmailFromIdentityKey(identityKey)
username := firstNonEmpty(
userInfoClaims.Username,
idClaims.PreferredUsername,
idClaims.Name,
func() string {
if idClaims != nil {
return idClaims.PreferredUsername
}
return ""
}(),
func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(),
oidcFallbackUsername(subject),
)
identityRef := service.PendingAuthIdentityKey{
......@@ -350,7 +380,12 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
"issuer": issuer,
"email_verified": emailVerified != nil && *emailVerified,
"provider_fallback": strings.TrimSpace(cfg.ProviderName),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username),
"suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, func() string {
if idClaims != nil {
return idClaims.Name
}
return ""
}(), username),
"suggested_avatar_url": userInfoClaims.AvatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
......@@ -387,24 +422,15 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
return
}
if existingIdentityUser != nil {
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityRef,
TargetUserID: &user.ID,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
"redirect": redirectTo,
},
}); err != nil {
......@@ -537,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if compatEmailUser != nil {
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
}
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
......@@ -596,6 +627,15 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
......@@ -608,12 +648,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
......@@ -621,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
......@@ -670,7 +713,9 @@ func oidcExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
r := client.R().
SetContext(ctx).
......@@ -872,9 +917,13 @@ func buildOIDCAuthorizeURL(cfg config.OIDCConnectConfig, state, nonce, codeChall
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if strings.TrimSpace(nonce) != "" {
q.Set("nonce", nonce)
}
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
u.RawQuery = q.Encode()
return u.String(), nil
......
......@@ -186,6 +186,89 @@ func TestOIDCOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
require.Equal(t, int64(84), userID)
}
func TestOIDCOAuthStartOmitsPKCEAndNonceWhenDisabled(t *testing.T) {
handler := newOIDCOAuthTestHandler(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/oauth/authorize",
TokenURL: "https://issuer.example.com/oauth/token",
UserInfoURL: "https://issuer.example.com/oauth/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/start?redirect=/dashboard", nil)
handler.OIDCOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.NotContains(t, location, "code_challenge=")
require.NotContains(t, location, "nonce=")
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthVerifierCookie))
require.Nil(t, findCookie(recorder.Result().Cookies(), oidcOAuthNonceCookie))
}
func TestOIDCOAuthCallbackAllowsOptionalPKCEAndIDTokenValidation(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"oidc-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"sub":"oidc-subject-compat","preferred_username":"oidc_user","name":"OIDC Display","email":"oidc@example.com"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newOIDCOAuthHandlerAndClient(t, false, config.OIDCConnectConfig{
Enabled: true,
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "openid profile email",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
ValidateIDToken: false,
RequireEmailVerified: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-123", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/oidc/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-login",
......@@ -250,10 +333,63 @@ func TestOIDCOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *t
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.NotEmpty(t, completion["access_token"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}
func TestOIDCOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-disabled-subject",
PreferredUsername: "oidc_disabled",
DisplayName: "OIDC Disabled",
})
defer cleanup()
handler, client := newOIDCOAuthHandlerAndClient(t, false, cfg)
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(oidcSyntheticEmailFromIdentityKey(oidcIdentityKey(cfg.IssuerURL, "oidc-disabled-subject"))).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("oidc").
SetProviderKey(cfg.IssuerURL).
SetProviderSubject("oidc-disabled-subject").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-disabled", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-disabled"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-disabled-subject"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-subject-compat",
......@@ -302,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
......@@ -606,6 +743,189 @@ func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-choice-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-choice-subject-1").
SetResolvedEmail("oidc-choice-subject-1@oidc-connect.invalid").
SetBrowserSessionKey("oidc-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-choice-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-no-adoption-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-subject-no-adoption").
SetResolvedEmail("8c9f12b2a2e14b1db9efc08b27e0ef5c@oidc-connect.invalid").
SetBrowserSessionKey("oidc-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
"suggested_display_name": "OIDC Legacy",
"suggested_avatar_url": "https://cdn.example/oidc-legacy.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser-no-adoption")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "oidc_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example.com"),
authidentity.ProviderSubjectEQ("oidc-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-conflict-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
SetBrowserSessionKey("oidc-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct {
Subject string
PreferredUsername string
......
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAuthHandlerRevokeAllSessionsInvalidatesAccessTokens(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 29,
Email: "session@example.com",
Username: "session-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 7,
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := &AuthHandler{authService: authService}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/auth/revoke-all-sessions", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 29})
handler.RevokeAllSessions(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, []int64{29}, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(8), repo.user.TokenVersion)
var resp struct {
Code int `json:"code"`
Data struct {
Message string `json:"message"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "All sessions have been revoked. Please log in again.", resp.Data.Message)
}
......@@ -279,12 +279,7 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "")
if err != nil {
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil {
if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, nil, nil, &existingIdentityUser.ID); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
......@@ -476,11 +471,12 @@ func (h *AuthHandler) WeChatPaymentOAuthCallback(c *gin.Context) {
}
func (h *AuthHandler) wechatPaymentResumeService() *service.PaymentResumeService {
var legacyKey []byte
key, err := payment.ProvideEncryptionKey(h.cfg)
if err != nil {
return service.NewPaymentResumeService(nil)
if err == nil {
legacyKey = []byte(key)
}
return service.NewPaymentResumeService([]byte(key))
return service.NewLegacyAwarePaymentResumeService(legacyKey)
}
type completeWeChatOAuthRequest struct {
......@@ -530,6 +526,15 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
......@@ -547,7 +552,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
......@@ -823,8 +828,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if user, err := singleWeChatIdentityUser(records); err != nil || user != nil {
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
}
openid = strings.TrimSpace(openid)
......@@ -847,8 +855,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if user, err := singleWeChatChannelUser(records); err != nil || user != nil {
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
}
if openid == "" {
......@@ -866,7 +877,11 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID(
if err != nil {
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
return singleWeChatIdentityUser(records)
user, err := singleWeChatIdentityUser(records)
if err != nil || user == nil {
return user, err
}
return findActiveUserByID(ctx, client, user.ID)
}
func wechatCompatibleProviderKeys(providerKey string) []string {
......
......@@ -213,6 +213,151 @@ func TestWeChatOAuthCallbackFallsBackToOpenIDWhenUnionIDMissingInSingleChannelMo
require.Equal(t, "third_party_signup", completion["choice_reason"])
}
func TestWeChatOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUserWithoutStoredTokens(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat-login.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback"))
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(wechatSyntheticEmail("union-456")).
SetUsername("wechat-existing-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("union-456").
SetMetadata(map[string]any{"username": "wechat-existing-user"}).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "https://app.example.com/auth/wechat/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail)
completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.Equal(t, "/dashboard", completion["redirect"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
}
func TestWeChatOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
wechatOAuthUserInfoURL = originalUserInfoURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-disabled","unionid":"union-disabled","scope":"snsapi_login"}`))
case strings.Contains(r.URL.Path, "/sns/userinfo"):
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"openid":"openid-disabled","unionid":"union-disabled","nickname":"Disabled WeChat","headimgurl":"https://cdn.example/disabled.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo"
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(wechatSyntheticEmail("union-disabled")).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("union-disabled").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-disabled", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard"))
req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open"))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.WeChatOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
......@@ -233,6 +378,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
defer client.Close()
handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
handler.cfg.Totp.EncryptionKeyConfigured = true
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
......@@ -270,6 +416,67 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T)
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
}
func TestWeChatPaymentOAuthCallbackUsesExplicitPaymentResumeSigningKeyWhenMixedKeysConfigured(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
t.Cleanup(func() {
wechatOAuthAccessTokenURL = originalAccessTokenURL
})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/sns/oauth2/access_token") {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-mixed-key","scope":"snsapi_base"}`))
return
}
http.NotFound(w, r)
}))
defer upstream.Close()
wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token"
handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback"))
defer client.Close()
legacyKeyHex := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
explicitSigningKey := "explicit-payment-resume-signing-key"
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", explicitSigningKey)
handler.cfg.Totp.EncryptionKey = legacyKeyHex
handler.cfg.Totp.EncryptionKeyConfigured = true
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/payment/callback?code=wechat-code&state=state-mixed", nil)
req.Host = "api.example.com"
req.AddCookie(encodedCookie(wechatPaymentOAuthStateName, "state-mixed"))
req.AddCookie(encodedCookie(wechatPaymentOAuthRedirect, "/purchase?from=wechat"))
req.AddCookie(encodedCookie(wechatPaymentOAuthContextName, `{"payment_type":"wxpay","amount":"18.8","order_type":"subscription","plan_id":9}`))
req.AddCookie(encodedCookie(wechatPaymentOAuthScope, "snsapi_base"))
c.Request = req
handler.WeChatPaymentOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
parsed, err := url.Parse(location)
require.NoError(t, err)
fragment, err := url.ParseQuery(parsed.Fragment)
require.NoError(t, err)
token := fragment.Get("wechat_resume_token")
require.NotEmpty(t, token)
claims, err := service.NewPaymentResumeService([]byte(explicitSigningKey)).ParseWeChatPaymentResumeToken(token)
require.NoError(t, err)
require.Equal(t, "openid-mixed-key", claims.OpenID)
require.Equal(t, payment.TypeWxpay, claims.PaymentType)
require.Equal(t, "18.8", claims.Amount)
require.Equal(t, payment.OrderTypeSubscription, claims.OrderType)
require.EqualValues(t, 9, claims.PlanID)
require.Equal(t, "/purchase?from=wechat", claims.RedirectTo)
_, err = service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")).ParseWeChatPaymentResumeToken(token)
require.Error(t, err)
}
func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) {
testCases := []struct {
name string
......@@ -620,7 +827,7 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes
require.Zero(t, count)
}
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) {
func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSessionReturnsPendingSession(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
t.Cleanup(func() {
......@@ -693,27 +900,32 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require.Equal(t, http.StatusOK, completeRecorder.Code)
responseData := decodeJSONBody(t, completeRecorder)
require.NotEmpty(t, responseData["access_token"])
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["adoption_required"])
require.Empty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")).
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "WeChat Display", userEntity.Username)
require.Nil(t, consumed.ConsumedAt)
identity, err := client.AuthIdentity.Query().
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-456"),
).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "WeChat Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"])
require.Zero(t, identityCount)
channel, err := client.AuthIdentityChannel.Query().
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
......@@ -721,25 +933,82 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
authidentitychannel.ChannelAppIDEQ("wx-open-app"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
Count(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "union-456", channel.Metadata["unionid"])
require.Zero(t, channelCount)
decision, err := client.IdentityAdoptionDecision.Query().
decisionCount, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)).
Count(ctx)
require.NoError(t, err)
require.Zero(t, decisionCount)
}
func TestCompleteWeChatOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-no-adoption-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey(wechatOAuthProviderKey).
SetProviderSubject("wechat-subject-no-adoption").
SetResolvedEmail("wechat-subject-no-adoption@wechat-connect.invalid").
SetBrowserSessionKey("wechat-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
"suggested_display_name": "WeChat Legacy",
"suggested_avatar_url": "https://cdn.example/wechat-legacy.png",
"mode": "open",
"channel": "open",
"channel_app_id": "wx-open-app",
"channel_subject": "openid-legacy",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
completeReq.Header.Set("Content-Type", "application/json")
completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-browser-no-adoption")})
completeCtx.Request = completeReq
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
require.Equal(t, "wechat_user", userEntity.Username)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(pendingSession.ID)).
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ(wechatOAuthProviderKey),
authidentity.ProviderSubjectEQ("wechat-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
......@@ -901,6 +1170,62 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteWeChatOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-choice-session").
SetIntent("login").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("wechat-choice-subject-1").
SetResolvedEmail("wechat-choice-subject-1@wechat-connect.invalid").
SetBrowserSessionKey("wechat-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-choice-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
originalAccessTokenURL := wechatOAuthAccessTokenURL
originalUserInfoURL := wechatOAuthUserInfoURL
......@@ -1083,18 +1408,6 @@ func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool,
}, client
}
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
parsed, err := url.Parse(location)
require.NoError(t, err)
fragment, err := url.ParseQuery(parsed.Fragment)
require.NoError(t, err)
require.Equal(t, errorCode, fragment.Get("error"))
require.Equal(t, errorMessage, fragment.Get("error_message"))
}
type wechatOAuthSettingRepoStub struct {
values map[string]string
}
......
......@@ -2,9 +2,9 @@ package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/payment"
......@@ -458,25 +458,61 @@ type PublicOrderResult struct {
OutTradeNo string `json:"out_trade_no"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
PaidAt *time.Time `json:"paid_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
RefundAmount float64 `json:"refund_amount"`
RefundReason *string `json:"refund_reason,omitempty"`
RefundRequestedAt *time.Time `json:"refund_requested_at,omitempty"`
RefundRequestedBy *string `json:"refund_requested_by,omitempty"`
RefundRequestReason *string `json:"refund_request_reason,omitempty"`
PlanID *int64 `json:"plan_id,omitempty"`
}
var errPaymentPublicOrderVerifyRemoved = infraerrors.New(
http.StatusGone,
"PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED",
"public payment order verification by out_trade_no has been removed; use resume_token recovery instead",
).WithMetadata(map[string]string{
"replacement_endpoint": "/api/v1/payment/public/orders/resolve",
"replacement_field": "resume_token",
})
// VerifyOrderPublic is kept as a compatibility shim for the removed anonymous
// out_trade_no lookup endpoint and always returns HTTP 410 Gone.
func buildPublicOrderResult(order *dbent.PaymentOrder) PublicOrderResult {
return PublicOrderResult{
ID: order.ID,
OutTradeNo: order.OutTradeNo,
Amount: order.Amount,
PayAmount: order.PayAmount,
FeeRate: order.FeeRate,
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
CreatedAt: order.CreatedAt,
ExpiresAt: order.ExpiresAt,
PaidAt: order.PaidAt,
CompletedAt: order.CompletedAt,
RefundAmount: order.RefundAmount,
RefundReason: order.RefundReason,
RefundRequestedAt: order.RefundRequestedAt,
RefundRequestedBy: order.RefundRequestedBy,
RefundRequestReason: order.RefundRequestReason,
PlanID: order.PlanID,
}
}
// VerifyOrderPublic keeps the legacy anonymous out_trade_no lookup available as
// a compatibility path for older result pages and staggered deploys.
// POST /api/v1/payment/public/orders/verify
func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
response.ErrorFrom(c, errPaymentPublicOrderVerifyRemoved)
var req VerifyOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
order, err := h.paymentService.VerifyOrderPublic(c.Request.Context(), req.OutTradeNo)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, buildPublicOrderResult(order))
}
// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token.
......@@ -493,15 +529,7 @@ func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Success(c, PublicOrderResult{
ID: order.ID,
OutTradeNo: order.OutTradeNo,
Amount: order.Amount,
PayAmount: order.PayAmount,
PaymentType: order.PaymentType,
OrderType: order.OrderType,
Status: order.Status,
})
response.Success(c, buildPublicOrderResult(order))
}
// requireAuth extracts the authenticated subject from the context.
......
......@@ -4,16 +4,17 @@ package handler
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
......@@ -74,7 +75,7 @@ func TestApplyWeChatPaymentResumeClaimsRejectsPaymentTypeMismatch(t *testing.T)
}
}
func TestVerifyOrderPublicReturnsGone(t *testing.T) {
func TestVerifyOrderPublicReturnsLegacyOrderState(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
......@@ -90,6 +91,32 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-verify@example.com").
SetPasswordHash("hash").
SetUsername("public-verify-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(88).
SetPayAmount(90.64).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-VERIFY").
SetOutTradeNo("legacy-order-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-verify").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPending).
SetExpiresAt(time.Now().Add(time.Hour)).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
......@@ -104,11 +131,238 @@ func TestVerifyOrderPublicReturnsGone(t *testing.T) {
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusGone, recorder.Code)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
ID int64 `json:"id"`
OutTradeNo string `json:"out_trade_no"`
Amount float64 `json:"amount"`
PayAmount float64 `json:"pay_amount"`
FeeRate float64 `json:"fee_rate"`
PaymentType string `json:"payment_type"`
OrderType string `json:"order_type"`
Status string `json:"status"`
RefundAmount float64 `json:"refund_amount"`
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, order.ID, resp.Data.ID)
require.Equal(t, "legacy-order-no", resp.Data.OutTradeNo)
require.Equal(t, 90.64, resp.Data.PayAmount)
require.Equal(t, 0.03, resp.Data.FeeRate)
require.Equal(t, payment.TypeAlipay, resp.Data.PaymentType)
require.Equal(t, payment.OrderTypeBalance, resp.Data.OrderType)
require.Equal(t, service.OrderStatusPending, resp.Data.Status)
require.Equal(t, 0.0, resp.Data.RefundAmount)
require.NotEmpty(t, resp.Data.CreatedAt)
require.NotEmpty(t, resp.Data.ExpiresAt)
}
func TestResolveOrderPublicByResumeTokenReturnsFrontendContractFields(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
var resp response.Response
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE").
SetOutTradeNo("resolve-order-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, float64(order.ID), resp.Data["id"])
require.Equal(t, "resolve-order-no", resp.Data["out_trade_no"])
require.Equal(t, 100.0, resp.Data["amount"])
require.Equal(t, 103.0, resp.Data["pay_amount"])
require.Equal(t, 0.03, resp.Data["fee_rate"])
require.Equal(t, payment.TypeAlipay, resp.Data["payment_type"])
require.Equal(t, payment.OrderTypeBalance, resp.Data["order_type"])
require.Equal(t, service.OrderStatusPaid, resp.Data["status"])
require.Contains(t, resp.Data, "created_at")
require.Contains(t, resp.Data, "expires_at")
require.Contains(t, resp.Data, "refund_amount")
}
func TestResolveOrderPublicByResumeTokenReturnsBadRequestForMismatchedToken(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
db, err := sql.Open("sqlite", "file:payment_handler_public_resolve_mismatch?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("public-resolve-mismatch@example.com").
SetPasswordHash("hash").
SetUsername("public-resolve-mismatch-user").
Save(context.Background())
require.NoError(t, err)
order, err := client.PaymentOrder.Create().
SetUserID(user.ID).
SetUserEmail(user.Email).
SetUserName(user.Username).
SetAmount(100).
SetPayAmount(103).
SetFeeRate(0.03).
SetRechargeCode("PUBLIC-RESOLVE-MISMATCH").
SetOutTradeNo("resolve-order-mismatch-no").
SetPaymentType(payment.TypeAlipay).
SetPaymentTradeNo("trade-public-resolve-mismatch").
SetOrderType(payment.OrderTypeBalance).
SetStatus(service.OrderStatusPaid).
SetExpiresAt(time.Now().Add(time.Hour)).
SetPaidAt(time.Now()).
SetClientIP("127.0.0.1").
SetSrcHost("api.example.com").
Save(context.Background())
require.NoError(t, err)
resumeSvc := service.NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := resumeSvc.CreateToken(service.ResumeTokenClaims{
OrderID: order.ID,
UserID: user.ID + 999,
PaymentType: payment.TypeAlipay,
CanonicalReturnURL: "https://app.example.com/payment/result",
})
require.NoError(t, err)
configSvc := service.NewPaymentConfigService(client, nil, []byte("0123456789abcdef0123456789abcdef"))
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, configSvc, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/resolve",
bytes.NewBufferString(`{"resume_token":"`+token+`"}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.ResolveOrderPublicByResumeToken(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
Message string `json:"message"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_RESUME_TOKEN", resp.Reason)
}
func TestVerifyOrderPublicRejectsBlankOutTradeNo(t *testing.T) {
gin.SetMode(gin.TestMode)
db, err := sql.Open("sqlite", "file:payment_handler_public_verify_blank?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
paymentSvc := service.NewPaymentService(client, payment.NewRegistry(), nil, nil, nil, nil, nil, nil)
h := NewPaymentHandler(paymentSvc, nil, nil)
recorder := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(recorder)
ctx.Request = httptest.NewRequest(
http.MethodPost,
"/api/v1/payment/public/orders/verify",
bytes.NewBufferString(`{"out_trade_no":" "}`),
)
ctx.Request.Header.Set("Content-Type", "application/json")
h.VerifyOrderPublic(ctx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
var resp struct {
Code int `json:"code"`
Reason string `json:"reason"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, http.StatusGone, resp.Code)
require.Equal(t, "PAYMENT_PUBLIC_ORDER_VERIFY_REMOVED", resp.Reason)
require.Contains(t, resp.Message, "removed")
require.Equal(t, http.StatusBadRequest, resp.Code)
require.Equal(t, "INVALID_OUT_TRADE_NO", resp.Reason)
}
......@@ -249,7 +249,7 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
return
}
updatedUser, err := h.userService.UnbindUserAuthProvider(
updatedUser, unbound, err := h.userService.UnbindUserAuthProviderWithResult(
c.Request.Context(),
subject.UserID,
c.Param("provider"),
......@@ -258,6 +258,12 @@ func (h *UserHandler) UnbindIdentity(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if unbound && h.authService != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
response.ErrorFrom(c, err)
return
}
}
profileResp, err := h.buildUserProfileResponse(c.Request.Context(), subject.UserID, updatedUser)
if err != nil {
......@@ -504,8 +510,12 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
thirdParty := thirdPartyIdentityProviders(identities)
var avatarSource *userProfileSourceContext
if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 {
avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider)
avatarValue := strings.TrimSpace(user.AvatarURL)
for _, summary := range thirdParty {
if avatarValue != "" && avatarValue == strings.TrimSpace(summary.AvatarURL) {
avatarSource = buildUserProfileSourceContext(summary.Provider)
break
}
}
usernameValue := strings.TrimSpace(user.Username)
......@@ -516,9 +526,6 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
break
}
}
if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 {
usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider)
}
profileSources := map[string]*userProfileSourceContext{}
if avatarSource != nil {
......
......@@ -253,7 +253,7 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) {
require.Equal(t, "https://issuer.example.com", resp.Data.Identities.OIDC.ProviderKey)
require.False(t, resp.Data.Identities.WeChat.Bound)
require.True(t, resp.Data.Identities.WeChat.CanBind)
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/bind/start")
}
func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
......@@ -278,6 +278,7 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
......@@ -331,10 +332,102 @@ func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) {
require.Equal(t, "linuxdo", usernameSource["source"])
}
func TestUserHandlerGetProfileDoesNotInferEditedProfileSourcesWithoutMatchingIdentityMetadata(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 22,
Email: "edited-profile@example.com",
Username: "custom-name",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/custom.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-22",
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 22})
handler.GetProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.NotContains(t, resp.Data, "avatar_source")
require.NotContains(t, resp.Data, "username_source")
require.NotContains(t, resp.Data, "profile_sources")
}
type userHandlerEmailCacheStub struct {
data *service.VerificationCodeData
}
type userHandlerRefreshTokenCacheStub struct {
revokedUserIDs []int64
}
func (s *userHandlerRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) {
return nil, service.ErrRefreshTokenNotFound
}
func (s *userHandlerRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) DeleteUserRefreshTokens(_ context.Context, userID int64) error {
s.revokedUserIDs = append(s.revokedUserIDs, userID)
return nil
}
func (s *userHandlerRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error {
return nil
}
func (s *userHandlerRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) {
return nil, nil
}
func (s *userHandlerRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) {
return nil, nil
}
func (s *userHandlerRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) {
return false, nil
}
func (s *userHandlerEmailCacheStub) GetVerificationCode(context.Context, string) (*service.VerificationCodeData, error) {
return s.data, nil
}
......@@ -495,6 +588,98 @@ func TestUserHandlerUnbindIdentityReturnsUpdatedProfile(t *testing.T) {
require.Equal(t, false, linuxdoBinding["bound"])
}
func TestUserHandlerUnbindIdentityRevokesAllUserSessionsWhenAuthServiceConfigured(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 23,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-23",
},
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 23})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, []int64{23}, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(5), repo.user.TokenVersion)
}
func TestUserHandlerUnbindIdentityDoesNotRevokeSessionsWhenNothingWasUnbound(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 24,
Email: "identity@example.com",
Username: "identity-user",
Role: service.RoleUser,
Status: service.StatusActive,
TokenVersion: 4,
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "email",
ProviderKey: "email",
ProviderSubject: "identity@example.com",
},
},
}
refreshTokenCache := &userHandlerRefreshTokenCacheStub{}
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
}
authService := service.NewAuthService(nil, repo, nil, refreshTokenCache, cfg, nil, nil, nil, nil, nil, nil)
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), authService, nil, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodDelete, "/api/v1/user/account-bindings/linuxdo", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 24})
c.Params = gin.Params{{Key: "provider", Value: "linuxdo"}}
handler.UnbindIdentity(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Empty(t, repo.unbound)
require.Empty(t, refreshTokenCache.revokedUserIDs)
require.Equal(t, int64(4), repo.user.TokenVersion)
}
func TestUserHandlerBindEmailIdentityRejectsWrongCurrentPasswordForBoundEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
......@@ -587,7 +772,7 @@ func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) {
require.Equal(t, "wechat", resp.Data.Provider)
require.Equal(t, "GET", resp.Data.Method)
require.True(t, resp.Data.UseBrowserRedirect)
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/start")
require.Contains(t, resp.Data.AuthorizeURL, "/api/v1/auth/oauth/wechat/bind/start")
require.Contains(t, resp.Data.AuthorizeURL, "intent=bind_current_user")
require.Contains(t, resp.Data.AuthorizeURL, "redirect=%2Fsettings%2Fprofile")
}
......@@ -60,11 +60,6 @@ const (
wxpayEventTransactionSuccess = "TRANSACTION.SUCCESS"
)
// WeChat Pay error codes.
const (
wxpayErrNoAuth = "NO_AUTH"
)
var (
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
return svc.Prepay(ctx, req)
......@@ -200,14 +195,7 @@ func (w *Wxpay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequ
case wxpayModeJSAPI:
return w.prepayJSAPI(ctx, client, req, notifyURL, totalFen)
case wxpayModeH5:
resp, err := w.prepayH5(ctx, client, req, notifyURL, totalFen)
if err == nil {
return resp, nil
}
if strings.Contains(err.Error(), wxpayErrNoAuth) {
return nil, fmt.Errorf("wxpay h5 payments are not authorized for this merchant: %w", err)
}
return nil, err
return w.prepayH5(ctx, client, req, notifyURL, totalFen)
case wxpayModeNative:
return w.prepayNative(ctx, client, req, notifyURL, totalFen)
default:
......
......@@ -8,6 +8,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"net/url"
"strings"
"testing"
......@@ -641,3 +642,68 @@ func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
}
}
func TestCreatePaymentMobileH5ReturnsNoAuthErrorWithoutNativeFallback(t *testing.T) {
origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
origNativePrepay := wxpayNativePrepay
origH5Prepay := wxpayH5Prepay
t.Cleanup(func() {
wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
wxpayNativePrepay = origNativePrepay
wxpayH5Prepay = origH5Prepay
})
jsapiCalls := 0
nativeCalls := 0
h5Calls := 0
wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
jsapiCalls++
return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
}
wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
h5Calls++
return nil, nil, errors.New("NO_AUTH")
}
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
nativeCalls++
return &native.PrepayResponse{
CodeUrl: core.String("weixin://wxpay/bizpayurl?pr=fallback-native"),
}, nil, nil
}
provider := &Wxpay{
config: map[string]string{
"appId": "wx123",
"mchId": "mch123",
},
coreClient: &core.Client{},
}
resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_100",
Amount: "66.88",
PaymentType: payment.TypeWxpay,
Subject: "Balance Recharge",
NotifyURL: "https://merchant.example/payment/notify",
ClientIP: "203.0.113.10",
IsMobile: true,
})
if err == nil {
t.Fatal("expected no-auth error, got nil")
}
if jsapiCalls != 0 {
t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
}
if h5Calls != 1 {
t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
}
if nativeCalls != 0 {
t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
}
if resp != nil {
t.Fatalf("expected nil response, got %+v", resp)
}
if !strings.Contains(err.Error(), "NO_AUTH") {
t.Fatalf("error = %v, want NO_AUTH", err)
}
}
......@@ -4,6 +4,7 @@ import (
"encoding/hex"
"fmt"
"log/slog"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
......@@ -19,11 +20,22 @@ type EncryptionKey []byte
// When the key is non-empty but invalid (bad hex or wrong length), an error is returned
// to prevent startup with a misconfigured encryption key.
func ProvideEncryptionKey(cfg *config.Config) (EncryptionKey, error) {
if cfg.Totp.EncryptionKey == "" {
if cfg == nil {
slog.Warn("payment encryption key not configured — encrypted payment config and resume signing will be unavailable")
return nil, nil
}
keyHex := strings.TrimSpace(cfg.Totp.EncryptionKey)
if keyHex == "" {
slog.Warn("payment encryption key not configured — encrypted payment config will be unavailable")
return nil, nil
}
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
// Reject auto-generated TOTP keys for payment signing.
// They change across restarts/instances and can silently break resume-token flows.
if !cfg.Totp.EncryptionKeyConfigured {
slog.Warn("payment encryption/signing key is not explicitly configured; set TOTP_ENCRYPTION_KEY to enable payment resume tokens")
return nil, nil
}
key, err := hex.DecodeString(keyHex)
if err != nil {
return nil, fmt.Errorf("invalid payment encryption key (hex decode): %w", err)
}
......
package payment
import (
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
func TestProvideEncryptionKeySkipsAutoGeneratedTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: strings.Repeat("a", 64),
EncryptionKeyConfigured: false,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 0 {
t.Fatalf("encryption key len = %d, want 0", len(key))
}
}
func TestProvideEncryptionKeyUsesConfiguredTotpKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
EncryptionKeyConfigured: true,
},
}
key, err := ProvideEncryptionKey(cfg)
if err != nil {
t.Fatalf("ProvideEncryptionKey returned error: %v", err)
}
if len(key) != 32 {
t.Fatalf("encryption key len = %d, want 32", len(key))
}
}
func TestProvideEncryptionKeyRejectsConfiguredInvalidLength(t *testing.T) {
t.Parallel()
cfg := &config.Config{
Totp: config.TotpConfig{
EncryptionKey: "abcd",
EncryptionKeyConfigured: true,
},
}
_, err := ProvideEncryptionKey(cfg)
if err == nil {
t.Fatal("expected error for invalid key length")
}
}
......@@ -4,6 +4,7 @@ package repository
import (
"context"
"database/sql"
"os"
"path/filepath"
"strconv"
......@@ -20,32 +21,8 @@ func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
......@@ -218,32 +195,8 @@ func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectM
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoMalformedUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
......@@ -408,32 +361,8 @@ func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngrades
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
userIDs := make([]int64, 0, 8)
for _, email := range []string{
......@@ -643,6 +572,388 @@ FROM auth_identity_migration_reports
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
}
func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`, wechatFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`, wechatSecondUserID).Scan(new(int64)))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxDoIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`).Scan(&linuxDoIdentityCount))
require.Zero(t, linuxDoIdentityCount)
var wechatIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`).Scan(&wechatIdentityCount))
require.Zero(t, wechatIdentityCount)
var wechatChannelCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`).Scan(&wechatChannelCount))
require.Zero(t, wechatChannelCount)
}
func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migration115SQL, err := os.ReadFile(migration115Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
var linuxDoFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
var linuxDoSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
var wechatFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
var wechatSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
_, err = tx.ExecContext(ctx, string(migration115SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var identityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`).Scan(&identityCount))
require.Zero(t, identityCount)
var conflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
require.Equal(t, 4, conflictReportCount)
var winnerAttributedReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
require.Zero(t, winnerAttributedReportCount)
}
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration108aPath := filepath.Join("..", "..", "migrations", "108a_widen_auth_identity_migration_report_type.sql")
migration108aSQL, err := os.ReadFile(migration108aPath)
require.NoError(t, err)
migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
migration109SQL, err := os.ReadFile(migration109Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
_, err = tx.ExecContext(ctx, `
ALTER TABLE auth_identity_migration_reports
ALTER COLUMN report_type TYPE VARCHAR(40);
`)
require.NoError(t, err)
var oidcSyntheticUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-before-121@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&oidcSyntheticUserID))
var linuxdoLegacyUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-before-121@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxdoLegacyUserID))
var invalidMetadataLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-before-121', NULL, 'legacy-linuxdo-before-121', 'Legacy LinuxDo Before 121', '{invalid')
RETURNING id
`, linuxdoLegacyUserID).Scan(&invalidMetadataLegacyID))
_, err = tx.ExecContext(ctx, string(migration108aSQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration109SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var reportTypeWidth int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`).Scan(&reportTypeWidth))
require.Equal(t, 80, reportTypeWidth)
var oidcSyntheticRecoveryReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`, strconv.FormatInt(oidcSyntheticUserID, 10)).Scan(&oidcSyntheticRecoveryReportCount))
require.Equal(t, 1, oidcSyntheticRecoveryReportCount)
var invalidMetadataReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(invalidMetadataLegacyID, 10)).Scan(&invalidMetadataReportCount))
require.Equal(t, 1, invalidMetadataReportCount)
}
func prepareLegacyExternalIdentitiesTable(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
require.NoError(t, err)
}
func truncateAuthIdentityLegacyFixtureTables(t *testing.T, tx *sql.Tx, ctx context.Context) {
t.Helper()
_, err := tx.ExecContext(ctx, `
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
pending_auth_sessions,
auth_identities,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
}
......@@ -51,34 +51,30 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql"
const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct {
fileChecksum string
acceptedDBChecksum map[string]struct{}
acceptedChecksums map[string]struct{}
}
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
// 规则必须同时匹配「迁移名 + 数据库 checksum + 当前文件 checksum」且两者都落在该迁移的已知版本集合内才会放行,
// 避免放宽全局校验,也允许将误改的历史 migration 回滚为已发布版本而不要求人工修 checksum。
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
"054_drop_legacy_cache_columns.sql": {
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
acceptedDBChecksum: map[string]struct{}{
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
},
},
"061_add_usage_log_request_type.sql": {
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
acceptedDBChecksum: map[string]struct{}{
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
"109_auth_identity_compat_backfill.sql": {
fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
acceptedDBChecksum: map[string]struct{}{
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {},
},
},
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
"110_pending_auth_and_provider_default_grants.sql": newMigrationChecksumCompatibilityRule("32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279", "e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925"),
"112_add_payment_order_provider_key_snapshot.sql": newMigrationChecksumCompatibilityRule("b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99", "ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e"),
"115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
"116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227", "a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb"),
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
......@@ -205,6 +201,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if nonTx {
if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
return fmt.Errorf("prepare migration %s: %w", name, err)
}
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
......@@ -254,6 +254,90 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
switch name {
case paymentOrdersOutTradeNoUniqueMigration:
return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
default:
return nil
}
}
func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
if err != nil {
return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
}
if len(duplicates) > 0 {
return fmt.Errorf(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
paymentOrdersOutTradeNoUniqueMigration,
strings.Join(duplicates, ", "),
)
}
invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
if err != nil {
return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
if !invalid {
return nil
}
if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
return nil
}
func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
rows, err := db.QueryContext(ctx, `
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
duplicates := make([]string, 0, 5)
for rows.Next() {
var outTradeNo string
var duplicateCount int
if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
return nil, err
}
duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
}
if err := rows.Err(); err != nil {
return nil, err
}
return duplicates, nil
}
func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
var invalid bool
err := db.QueryRowContext(ctx, `
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`, indexName).Scan(&invalid)
return invalid, err
}
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
......@@ -328,16 +412,33 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
func checksumSet(values ...string) map[string]struct{} {
out := make(map[string]struct{}, len(values))
for _, value := range values {
out[value] = struct{}{}
}
return out
}
func newMigrationChecksumCompatibilityRule(fileChecksum string, acceptedDBChecksums ...string) migrationChecksumCompatibilityRule {
return migrationChecksumCompatibilityRule{
fileChecksum: fileChecksum,
acceptedDBChecksum: checksumSet(acceptedDBChecksums...),
acceptedChecksums: checksumSet(append([]string{fileChecksum}, acceptedDBChecksums...)...),
}
}
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
rule, ok := migrationChecksumCompatibilityRules[name]
if !ok {
return false
}
if rule.fileChecksum != fileChecksum {
_, dbOK := rule.acceptedChecksums[dbChecksum]
if !dbOK {
return false
}
_, ok = rule.acceptedDBChecksum[dbChecksum]
return ok
_, fileOK := rule.acceptedChecksums[fileChecksum]
return fileOK
}
func validateMigrationExecutionMode(name, content string) (bool, error) {
......
......@@ -55,9 +55,110 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
t.Run("109历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
)
require.True(t, ok)
})
t.Run("109当前checksum可兼容历史checksum", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
)
require.True(t, ok)
})
t.Run("109回滚到历史文件后仍兼容已应用的新checksum", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"0580b4602d85435edf9aca1633db580bb3932f26517f75134106f80275ec2ace",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
)
require.True(t, ok)
})
t.Run("110历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"110_pending_auth_and_provider_default_grants.sql",
"e3d1f433be2b564cfbdc549adf98fce13c5c7b363ebc20fd05b765d0563b0925",
"32cf87ee787b1bb36b5c691367c96eee37518fa3eed6f3322cf68795e3745279",
)
require.True(t, ok)
})
t.Run("112历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"112_add_payment_order_provider_key_snapshot.sql",
"ffd3e8a2c9295fa9cbefefd629a78268877e5b51bc970a82d9b3f46ec4ebd15e",
"b75f8f56d39455682787696a3d92ad25b055444ca328fb7fca9a460a15d68d99",
)
require.True(t, ok)
})
t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"115_auth_identity_legacy_external_backfill.sql",
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
)
require.True(t, ok)
})
t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"116_auth_identity_legacy_external_safety_reports.sql",
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
)
require.True(t, ok)
})
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql",
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e",
)
require.True(t, ok)
})
t.Run("118多个历史checksum都可兼容当前版本", func(t *testing.T) {
for _, dbChecksum := range []string{
"a38243ca0a72c3a01c0a92b7986423054d6133c0399441f853b99802852720fb",
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227",
} {
ok := isMigrationChecksumCompatible(
"118_wechat_dual_mode_and_auth_source_defaults.sql",
dbChecksum,
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0",
)
require.True(t, ok)
}
})
t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
for _, dbChecksum := range []string{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
} {
ok := isMigrationChecksumCompatible(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
dbChecksum,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
)
require.True(t, ok)
}
})
t.Run("119未知checksum不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql",
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34",
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
)
require.False(t, ok)
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment