Commit fb6204ea authored by IanShaw027's avatar IanShaw027
Browse files

feat: apply oauth first-bind defaults and pending bind 2fa

parent 6ea3f42e
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
...@@ -269,6 +270,62 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { ...@@ -269,6 +270,62 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
return return
} }
if session.PendingOAuthBind != nil {
pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
pendingSession, err := pendingSvc.GetBrowserSession(
c.Request.Context(),
session.PendingOAuthBind.PendingSessionToken,
session.PendingOAuthBind.BrowserSessionKey,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(
c.Request.Context(),
h.entClient(),
h.authService,
pendingSession,
decision,
&user.ID,
true,
true,
); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(
c.Request.Context(),
pendingSession.SessionToken,
pendingSession.BrowserSessionKey,
); err != nil {
response.ErrorFrom(c, err)
return
}
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
// Delete the login session (only after all checks pass) // Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
......
...@@ -436,7 +436,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -436,7 +436,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
......
...@@ -601,10 +601,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision ...@@ -601,10 +601,12 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision
func applyPendingOAuthBinding( func applyPendingOAuthBinding(
ctx context.Context, ctx context.Context,
client *dbent.Client, client *dbent.Client,
authService *service.AuthService,
session *dbent.PendingAuthSession, session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision, decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64, overrideUserID *int64,
forceBind bool, forceBind bool,
applyFirstBindDefaults bool,
) error { ) error {
if client == nil || session == nil { if client == nil || session == nil {
return nil return nil
...@@ -638,16 +640,17 @@ func applyPendingOAuthBinding( ...@@ -638,16 +640,17 @@ func applyPendingOAuthBinding(
return err return err
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID). if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName). SetUsername(adoptedDisplayName).
Exec(ctx); err != nil { Exec(txCtx); err != nil {
return err return err
} }
} }
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) identity, err := ensurePendingOAuthIdentityForUser(txCtx, tx, session, targetUserID)
if err != nil { if err != nil {
return err return err
} }
...@@ -667,14 +670,20 @@ func applyPendingOAuthBinding( ...@@ -667,14 +670,20 @@ func applyPendingOAuthBinding(
if issuer := oauthIdentityIssuer(session); issuer != nil { if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
} }
if _, err := updateIdentity.Save(ctx); err != nil { if _, err := updateIdentity.Save(txCtx); err != nil {
return err return err
} }
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID). SetIdentityID(identity.ID).
Save(ctx); err != nil { Save(txCtx); err != nil {
return err
}
}
if applyFirstBindDefaults && authService != nil {
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(txCtx, targetUserID, session.ProviderType); err != nil {
return err return err
} }
} }
...@@ -685,11 +694,21 @@ func applyPendingOAuthBinding( ...@@ -685,11 +694,21 @@ func applyPendingOAuthBinding(
func applyPendingOAuthAdoption( func applyPendingOAuthAdoption(
ctx context.Context, ctx context.Context,
client *dbent.Client, client *dbent.Client,
authService *service.AuthService,
session *dbent.PendingAuthSession, session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision, decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64, overrideUserID *int64,
) error { ) error {
return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false) return applyPendingOAuthBinding(
ctx,
client,
authService,
session,
decision,
overrideUserID,
false,
strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
)
} }
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
...@@ -804,7 +823,26 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { ...@@ -804,7 +823,26 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil { if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
c.Request.Context(),
user.ID,
user.Email,
session.SessionToken,
session.BrowserSessionKey,
)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return return
} }
...@@ -900,7 +938,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -900,7 +938,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil { if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return return
} }
...@@ -990,7 +1028,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { ...@@ -990,7 +1028,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil { if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
......
...@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
......
...@@ -346,7 +346,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { ...@@ -346,7 +346,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return return
} }
......
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
entsql "entgo.io/ent/dialect/sql"
)
// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
// settings the first time a user binds a third-party identity. The grant is
// idempotent per user/provider pair.
func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
ctx context.Context,
userID int64,
providerType string,
) error {
if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
return nil
}
if dbent.TxFromContext(ctx) != nil {
return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
return fmt.Errorf("begin first bind defaults transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
return err
}
return tx.Commit()
}
func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
ctx context.Context,
userID int64,
providerType string,
) error {
defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx)
if err != nil {
return fmt.Errorf("load auth source defaults: %w", err)
}
providerDefaults, ok := authSourceSignupSettings(defaults, providerType)
if !ok || !providerDefaults.GrantOnFirstBind {
return nil
}
client := s.entClient
if tx := dbent.TxFromContext(ctx); tx != nil {
client = tx.Client()
}
var result entsql.Result
if err := client.Driver().Exec(
ctx,
`INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
VALUES (?, ?, ?)
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
[]any{userID, strings.TrimSpace(providerType), "first_bind"},
&result,
); err != nil {
return fmt.Errorf("record first bind provider grant: %w", err)
}
affected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("read first bind provider grant result: %w", err)
}
if affected == 0 {
return nil
}
if providerDefaults.Balance != 0 {
if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
return fmt.Errorf("apply first bind balance default: %w", err)
}
}
if providerDefaults.Concurrency != 0 {
if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
return fmt.Errorf("apply first bind concurrency default: %w", err)
}
}
if s.defaultSubAssigner != nil {
for _, item := range providerDefaults.Subscriptions {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by first bind defaults",
}); err != nil {
return fmt.Errorf("apply first bind subscription default: %w", err)
}
}
}
return nil
}
...@@ -58,9 +58,15 @@ type TotpSetupSession struct { ...@@ -58,9 +58,15 @@ type TotpSetupSession struct {
// TotpLoginSession represents a pending 2FA login session // TotpLoginSession represents a pending 2FA login session
type TotpLoginSession struct { type TotpLoginSession struct {
UserID int64 UserID int64
Email string Email string
TokenExpiry time.Time TokenExpiry time.Time
PendingOAuthBind *PendingOAuthBindLoginSession `json:"pending_oauth_bind,omitempty"`
}
type PendingOAuthBindLoginSession struct {
PendingSessionToken string `json:"pending_session_token,omitempty"`
BrowserSessionKey string `json:"browser_session_key,omitempty"`
} }
// TotpStatus represents the TOTP status for a user // TotpStatus represents the TOTP status for a user
...@@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) ...@@ -397,6 +403,30 @@ func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string)
// CreateLoginSession creates a temporary login session for 2FA // CreateLoginSession creates a temporary login session for 2FA
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) { func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
return s.createLoginSession(ctx, userID, email, nil)
}
// CreatePendingOAuthBindLoginSession creates a temporary 2FA session that will
// finalize a pending OAuth bind after the TOTP code is verified.
func (s *TotpService) CreatePendingOAuthBindLoginSession(
ctx context.Context,
userID int64,
email string,
pendingSessionToken string,
browserSessionKey string,
) (string, error) {
return s.createLoginSession(ctx, userID, email, &PendingOAuthBindLoginSession{
PendingSessionToken: pendingSessionToken,
BrowserSessionKey: browserSessionKey,
})
}
func (s *TotpService) createLoginSession(
ctx context.Context,
userID int64,
email string,
pendingOAuthBind *PendingOAuthBindLoginSession,
) (string, error) {
// Generate a random temp token // Generate a random temp token
tempToken, err := generateRandomToken(32) tempToken, err := generateRandomToken(32)
if err != nil { if err != nil {
...@@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai ...@@ -404,9 +434,10 @@ func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, emai
} }
session := &TotpLoginSession{ session := &TotpLoginSession{
UserID: userID, UserID: userID,
Email: email, Email: email,
TokenExpiry: time.Now().Add(totpLoginTTL), TokenExpiry: time.Now().Add(totpLoginTTL),
PendingOAuthBind: pendingOAuthBind,
} }
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil { if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment