"vscode:/vscode.git/clone" did not exist on "a29642599446f89e5962ae36d492a70673c53749"
Commit 7c6491c2 authored by IanShaw027's avatar IanShaw027
Browse files

fix auth pending session hardening

parent 1d8432b8
package handler package handler
import ( import (
"context"
"log/slog" "log/slog"
"strings" "strings"
...@@ -105,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { ...@@ -105,6 +106,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
}) })
} }
func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
if h == nil || !h.isBackendModeEnabled(ctx) {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
settings, err := h.settingSvc.GetPublicSettings(ctx)
if err == nil && settings != nil {
return settings.BackendModeEnabled
}
return h.settingSvc.IsBackendModeEnabled(ctx)
}
// Register handles user registration // Register handles user registration
// POST /api/v1/auth/register // POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) { func (h *AuthHandler) Register(c *gin.Context) {
...@@ -178,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) { ...@@ -178,6 +207,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
} }
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
// Check if TOTP 2FA is enabled for this user // Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA // Create a temporary login session for 2FA
...@@ -195,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) { ...@@ -195,11 +229,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
// Backend mode: only admin can login h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
...@@ -264,9 +294,8 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { ...@@ -264,9 +294,8 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return return
} }
// Backend mode: only admin can login (check BEFORE deleting session) if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() { response.ErrorFrom(c, err)
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return return
} }
...@@ -330,6 +359,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { ...@@ -330,6 +359,10 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
// Delete the login session (only after all checks pass) // Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
if session.PendingOAuthBind == nil {
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
h.respondWithTokenPair(c, user) h.respondWithTokenPair(c, user)
} }
......
...@@ -474,6 +474,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -474,6 +474,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail) email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
...@@ -499,6 +507,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -499,6 +507,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
...@@ -591,6 +591,58 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi ...@@ -591,6 +591,58 @@ func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testi
require.NotNil(t, consumed.ConsumedAt) require.NotNil(t, consumed.ConsumedAt)
} }
func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-invalid-subject-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("linuxdo-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper() t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
......
...@@ -253,6 +253,35 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { ...@@ -253,6 +253,35 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
} }
func pendingOAuthCompletionIncludesTokenPayload(payload map[string]any) bool {
if len(payload) == 0 {
return false
}
for _, key := range []string{"access_token", "refresh_token"} {
if value := pendingSessionStringValue(payload, key); value != "" {
return true
}
}
return false
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
if session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if strings.TrimSpace(session.Intent) != oauthIntentLogin {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload, _ := readCompletionResponse(session.LocalFlowState)
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
return nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool { func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
} }
...@@ -1090,6 +1119,10 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { ...@@ -1090,6 +1119,10 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return return
} }
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil { if err != nil {
...@@ -1192,6 +1225,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -1192,6 +1225,10 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return return
} }
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(), c.Request.Context(),
...@@ -1215,6 +1252,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -1215,6 +1252,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies() clearCookies()
...@@ -1279,6 +1317,25 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { ...@@ -1279,6 +1317,25 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
} }
} }
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
if pendingOAuthCompletionIncludesTokenPayload(payload) {
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
return
}
user, err := h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
if pendingSessionWantsInvitation(payload) { if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() { if adoptionDecision.hasDecision() {
......
...@@ -523,6 +523,60 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti ...@@ -523,6 +523,60 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdopti
require.NotNil(t, storedSession.ConsumedAt) require.NotNil(t, storedSession.ConsumedAt)
} }
func TestExchangePendingOAuthCompletionBlocksBackendModeBeforeReturningTokenPayload(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
userEntity, err := client.User.Create().
SetEmail("blocked@example.com").
SetUsername("blocked-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("blocked-backend-mode-session-token").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("blocked-subject-123").
SetTargetUserID(userEntity.ID).
SetResolvedEmail(userEntity.Email).
SetBrowserSessionKey("blocked-backend-mode-browser-session-key").
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"access_token": "access-token",
"refresh_token": "refresh-token",
"expires_in": float64(3600),
"token_type": "Bearer",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("blocked-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.ExchangePendingOAuthCompletion(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) { func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecisionWithoutBinding(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true) handler, client := newOAuthPendingFlowTestHandler(t, true)
ctx := context.Background() ctx := context.Background()
...@@ -773,6 +827,60 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te ...@@ -773,6 +827,60 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
} }
func TestCreateOIDCOAuthAccountBlocksBackendModeBeforeCreatingUser(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
emailVerifyEnabled: true,
emailCache: &oauthPendingFlowEmailCacheStub{
verificationCodes: map[string]*service.VerificationCodeData{
"fresh@example.com": {
Code: "246810",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
},
},
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("create-account-backend-mode-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-create-backend-mode-123").
SetBrowserSessionKey("create-account-backend-mode-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.CreateOIDCOAuthAccount(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false) handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background() ctx := context.Background()
...@@ -842,6 +950,70 @@ func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { ...@@ -842,6 +950,70 @@ func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) {
require.NotNil(t, storedSession.ConsumedAt) require.NotNil(t, storedSession.ConsumedAt)
} }
func TestBindOIDCOAuthLoginBlocksBackendModeBeforeTokenIssue(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyBackendModeEnabled: "true",
},
})
ctx := context.Background()
passwordHash, err := handler.authService.HashPassword("secret-123")
require.NoError(t, err)
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash(passwordHash).
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("bind-login-backend-mode-session-token").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("oidc-bind-backend-mode-123").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("bind-login-backend-mode-browser-session-key").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123"}`)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-backend-mode-browser-session-key")})
ginCtx.Request = req
handler.BindOIDCOAuthLogin(ginCtx)
require.Equal(t, http.StatusForbidden, recorder.Code)
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("oidc-bind-backend-mode-123"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false) handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background() ctx := context.Background()
......
...@@ -516,6 +516,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -516,6 +516,14 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail) email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
...@@ -541,6 +549,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -541,6 +549,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
...@@ -431,6 +431,58 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing. ...@@ -431,6 +431,58 @@ func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.
require.NotNil(t, consumed.ConsumedAt) require.NotNil(t, consumed.ConsumedAt)
} }
func TestCompleteOIDCOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-invalid-subject-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("oidc-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-invalid-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct { type oidcProviderFixture struct {
Subject string Subject string
PreferredUsername string PreferredUsername string
......
...@@ -506,6 +506,14 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { ...@@ -506,6 +506,14 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
email := strings.TrimSpace(session.ResolvedEmail) email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
...@@ -531,6 +539,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { ...@@ -531,6 +539,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
...@@ -846,6 +846,59 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { ...@@ -846,6 +846,59 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) {
require.Equal(t, repairedIdentity.ID, channel.IdentityID) require.Equal(t, repairedIdentity.ID, channel.IdentityID)
} }
func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newWeChatOAuthTestHandler(t, false)
defer client.Close()
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("wechat-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-invalid-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("wechat-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "wechat_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
completeCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("wechat-invalid-browser")})
completeCtx.Request = req
handler.CompleteWeChatOAuthRegistration(completeCtx)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) {
t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app")
t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret")
......
...@@ -104,7 +104,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( ...@@ -104,7 +104,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
s.postAuthUserBootstrap(ctx, user, signupSource, true) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
......
...@@ -430,8 +430,6 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -430,8 +430,6 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() { if !user.IsActive() {
return "", nil, ErrUserNotActive return "", nil, ErrUserNotActive
} }
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
s.touchUserLogin(ctx, user.ID)
// 生成JWT token // 生成JWT token
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
...@@ -507,7 +505,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -507,7 +505,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
} }
} else { } else {
...@@ -527,8 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -527,8 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
} }
} }
s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err) return "", nil, fmt.Errorf("generate token: %w", err)
...@@ -634,7 +630,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -634,7 +630,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
} }
} else { } else {
...@@ -651,7 +647,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -651,7 +647,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, signupSource, true) s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
...@@ -676,8 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -676,8 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
} }
} }
s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "") tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err) return nil, nil, fmt.Errorf("generate token pair: %w", err)
......
...@@ -170,24 +170,26 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { ...@@ -170,24 +170,26 @@ func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
require.NotNil(t, identity.VerifiedAt) require.NotNil(t, identity.VerifiedAt)
} }
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { func TestAuthServiceLoginDefersLastLoginTouchUntilRecordSuccessfulLogin(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t, map[string]string{ svc, _, client := newAuthServiceWithEnt(t, map[string]string{
service.SettingKeyRegistrationEnabled: "true", service.SettingKeyRegistrationEnabled: "true",
}, nil) }, nil)
ctx := context.Background() ctx := context.Background()
user := &service.User{ passwordHash, err := svc.HashPassword("password")
Email: "login@example.com", require.NoError(t, err)
Role: service.RoleUser, user, err := client.User.Create().
Status: service.StatusActive, SetEmail("login@example.com").
Balance: 1, SetPasswordHash(passwordHash).
Concurrency: 1, SetRole(service.RoleUser).
} SetStatus(service.StatusActive).
require.NoError(t, user.SetPassword("password")) SetBalance(1).
require.NoError(t, repo.Create(ctx, user)) SetConcurrency(1).
Save(ctx)
require.NoError(t, err)
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err := client.User.UpdateOneID(user.ID). _, err = client.User.UpdateOneID(user.ID).
SetLastLoginAt(old). SetLastLoginAt(old).
SetLastActiveAt(old). SetLastActiveAt(old).
Save(ctx) Save(ctx)
...@@ -202,8 +204,20 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { ...@@ -202,8 +204,20 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt) require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt) require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old)) require.True(t, storedUser.LastLoginAt.Equal(old))
require.True(t, storedUser.LastActiveAt.After(old)) require.True(t, storedUser.LastActiveAt.Equal(old))
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("login@example.com"),
).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
svc.RecordSuccessfulLogin(ctx, user.ID)
identity, err := client.AuthIdentity.Query(). identity, err := client.AuthIdentity.Query().
Where( Where(
...@@ -273,6 +287,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe ...@@ -273,6 +287,7 @@ func TestAuthServiceLogin_AppliesEmailFirstBindDefaultsOnlyWhenEmailIdentityIsNe
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, token) require.NotEmpty(t, token)
require.NotNil(t, gotUser) require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err) require.NoError(t, err)
...@@ -343,6 +358,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE ...@@ -343,6 +358,7 @@ func TestAuthServiceLogin_DoesNotApplyEmailFirstBindDefaultsWhenIdentityAlreadyE
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, token) require.NotEmpty(t, token)
require.NotNil(t, gotUser) require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err) require.NoError(t, err)
...@@ -380,6 +396,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t ...@@ -380,6 +396,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, token) require.NotEmpty(t, token)
require.NotNil(t, gotUser) require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err := client.User.Get(ctx, user.ID) storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err) require.NoError(t, err)
...@@ -392,6 +409,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t ...@@ -392,6 +409,7 @@ func TestAuthServiceLogin_RetriesEmailFirstBindDefaultsAfterPreviousFailure(t *t
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, token) require.NotEmpty(t, token)
require.NotNil(t, gotUser) require.NotNil(t, gotUser)
svc.RecordSuccessfulLogin(ctx, user.ID)
storedUser, err = client.User.Get(ctx, user.ID) storedUser, err = client.User.Get(ctx, user.ID)
require.NoError(t, err) require.NoError(t, err)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment