Commit 36aed359 authored by IanShaw027's avatar IanShaw027
Browse files

fix(auth): harden oauth identity upgrade paths

parent 3d29f7c2
...@@ -3,7 +3,10 @@ package repository ...@@ -3,7 +3,10 @@ package repository
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"sync"
"testing" "testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/enttest"
...@@ -18,9 +21,10 @@ import ( ...@@ -18,9 +21,10 @@ import (
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
t.Helper() t.Helper()
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() }) t.Cleanup(func() { _ = db.Close() })
db.SetMaxOpenConns(10)
_, err = db.Exec("PRAGMA foreign_keys = ON") _, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err) require.NoError(t, err)
...@@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) { ...@@ -144,3 +148,80 @@ func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.ErrorContains(t, err, "normalized email lookup matched multiple users") require.ErrorContains(t, err, "normalized email lookup matched multiple users")
} }
func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.User.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type createResult struct {
err error
}
results := make(chan createResult, 2)
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: " Race@Example.com ",
Username: "race-user-1",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
<-firstCreateStarted
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: "race@example.com",
Username: "race-user-2",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
errors := []error{first.err, second.err}
successes := 0
conflicts := 0
for _, err := range errors {
switch {
case err == nil:
successes++
case err == service.ErrEmailExists:
conflicts++
default:
t.Fatalf("unexpected create error: %v", err)
}
}
require.Equal(t, 1, successes)
require.Equal(t, 1, conflicts)
count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
}
...@@ -14,10 +14,14 @@ import ( ...@@ -14,10 +14,14 @@ import (
func normalizeOAuthSignupSource(signupSource string) string { func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource)) signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" { switch signupSource {
case "", "email":
return "email" return "email"
} case "linuxdo", "wechat", "oidc":
return signupSource return signupSource
default:
return "email"
}
} }
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth // SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
...@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( ...@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, fmt.Errorf("hash password: %w", err) return nil, nil, fmt.Errorf("hash password: %w", err)
} }
signupSource = strings.TrimSpace(strings.ToLower(signupSource)) signupSource = normalizeOAuthSignupSource(signupSource)
if signupSource == "" {
signupSource = "email"
}
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{ user := &User{
...@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( ...@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource,
} }
if err := s.userRepo.Create(ctx, user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
......
...@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai ...@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
require.Empty(t, redeemRepo.updateCalls) require.Empty(t, redeemRepo.updateCalls)
} }
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"",
" OIDC ",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fallback@example.com",
"secret-123",
"246810",
"",
"github",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "email", userRepo.created[0].SignupSource)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) { func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{} userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{ redeemRepo := &redeemCodeRepoStub{
......
...@@ -5,10 +5,15 @@ import ( ...@@ -5,10 +5,15 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"hash/fnv"
"sort"
"strings" "strings"
"sync"
"time" "time"
"entgo.io/ent/dialect"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
...@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct { ...@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
entClient *dbent.Client entClient *dbent.Client
} }
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
type authPendingIdentityScopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*authPendingIdentityScopedKeyLockEntry
}
type authPendingIdentityScopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
return &authPendingIdentityScopedKeyLockRegistry{
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
}
}
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &authPendingIdentityScopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func authPendingIdentityAdvisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
release := authPendingIdentityScopedKeyLocks.lock(keys...)
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
var rows entsql.Rows
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
}
return keys
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient} return &AuthPendingIdentityService{entClient: entClient}
} }
...@@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, ...@@ -324,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
return nil, fmt.Errorf("pending auth ent client is not configured") return nil, fmt.Errorf("pending auth ent client is not configured")
} }
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
client := s.entClient
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
client = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
client = existingTx.Client()
}
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
if err != nil {
return nil, err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 { if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := s.entClient.IdentityAdoptionDecision.Update(). if _, err := client.IdentityAdoptionDecision.Update().
Where( Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID), identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
...@@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, ...@@ -337,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
}), }),
). ).
ClearIdentityID(). ClearIdentityID().
Save(ctx); err != nil { Save(txCtx); err != nil {
return nil, err return nil, err
} }
} }
existing, err := s.entClient.IdentityAdoptionDecision.Query(). create := client.IdentityAdoptionDecision.Create().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID). SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName). SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar). SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC()) SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil { if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID) create = create.SetIdentityID(*input.IdentityID)
} }
return create.Save(ctx)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return nil, err
} }
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
SetAdoptDisplayName(input.AdoptDisplayName). if err != nil {
SetAdoptAvatar(input.AdoptAvatar) return nil, err
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
} }
return update.Save(ctx)
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
return decision, nil
} }
func copyPendingMap(in map[string]any) map[string]any { func copyPendingMap(in map[string]any) map[string]any {
......
...@@ -5,6 +5,7 @@ package service ...@@ -5,6 +5,7 @@ package service
import ( import (
"context" "context"
"database/sql" "database/sql"
"sync"
"testing" "testing"
"time" "time"
...@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden ...@@ -259,6 +260,107 @@ func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIden
require.Nil(t, reloadedFirst.IdentityID) require.Nil(t, reloadedFirst.IdentityID)
} }
func TestAuthPendingIdentityService_UpsertAdoptionDecision_IsIdempotentUnderConcurrency(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption-concurrent@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-concurrent").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-concurrent",
},
})
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := svc.UpsertAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) { func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL") t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
...@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -489,6 +490,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource,
} }
if err := s.userRepo.Create(ctx, newUser); err != nil { if err := s.userRepo.Create(ctx, newUser); err != nil {
...@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -599,6 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
Balance: grantPlan.Balance, Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency, Concurrency: grantPlan.Concurrency,
Status: StatusActive, Status: StatusActive,
SignupSource: signupSource,
} }
if s.entClient != nil && invitationRedeemCode != nil { if s.entClient != nil && invitationRedeemCode != nil {
...@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { ...@@ -1048,7 +1051,7 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
UserID: user.ID, UserID: user.ID,
Email: user.Email, Email: user.Email,
Role: user.Role, Role: user.Role,
TokenVersion: user.TokenVersion, TokenVersion: resolvedTokenVersion(user),
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt), ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
...@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( ...@@ -1114,7 +1117,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
// Security: Check TokenVersion to prevent refreshing revoked tokens // Security: Check TokenVersion to prevent refreshing revoked tokens
// This ensures tokens issued before a password change cannot be refreshed // This ensures tokens issued before a password change cannot be refreshed
if claims.TokenVersion != user.TokenVersion { if claims.TokenVersion != resolvedTokenVersion(user) {
return "", ErrTokenRevoked return "", ErrTokenRevoked
} }
...@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami ...@@ -1342,7 +1345,7 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami
data := &RefreshTokenData{ data := &RefreshTokenData{
UserID: user.ID, UserID: user.ID,
TokenVersion: user.TokenVersion, TokenVersion: resolvedTokenVersion(user),
FamilyID: familyID, FamilyID: familyID,
CreatedAt: now, CreatedAt: now,
ExpiresAt: now.Add(ttl), ExpiresAt: now.Add(ttl),
...@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) ...@@ -1422,7 +1425,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string)
} }
// 检查TokenVersion(密码更改后所有Token失效) // 检查TokenVersion(密码更改后所有Token失效)
if data.TokenVersion != user.TokenVersion { if data.TokenVersion != resolvedTokenVersion(user) {
// TokenVersion不匹配,撤销整个Token家族 // TokenVersion不匹配,撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked return nil, ErrTokenRevoked
...@@ -1492,3 +1495,14 @@ func hashToken(token string) string { ...@@ -1492,3 +1495,14 @@ func hashToken(token string) string {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
func resolvedTokenVersion(user *User) int64 {
if user == nil {
return 0
}
material := strings.ToLower(strings.TrimSpace(user.Email)) + "\n" + user.PasswordHash
sum := sha256.Sum256([]byte(material))
fingerprint := int64(binary.BigEndian.Uint64(sum[:8]) & 0x7fffffffffffffff)
return user.TokenVersion ^ fingerprint
}
...@@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string { ...@@ -814,6 +814,20 @@ func parseCustomMenuItemURLs(raw string) []string {
return urls return urls
} }
func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.UsePKCEExplicit {
return base.UsePKCE
}
return false
}
func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool {
if base.ValidateIDTokenExplicit {
return base.ValidateIDToken
}
return false
}
// UpdateSettings 更新系统设置 // UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
updates, err := s.buildSystemSettingsUpdates(ctx, settings) updates, err := s.buildSystemSettingsUpdates(ctx, settings)
...@@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -1479,6 +1493,17 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
return fmt.Errorf("check existing settings: %w", err) return fmt.Errorf("check existing settings: %w", err)
} }
oidcUsePKCEDefault := true
oidcValidateIDTokenDefault := true
if s != nil && s.cfg != nil {
if s.cfg.OIDC.UsePKCEExplicit {
oidcUsePKCEDefault = s.cfg.OIDC.UsePKCE
}
if s.cfg.OIDC.ValidateIDTokenExplicit {
oidcValidateIDTokenDefault = s.cfg.OIDC.ValidateIDToken
}
}
// 初始化默认设置 // 初始化默认设置
defaults := map[string]string{ defaults := map[string]string{
SettingKeyRegistrationEnabled: "true", SettingKeyRegistrationEnabled: "true",
...@@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { ...@@ -1523,8 +1548,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOIDCConnectRedirectURL: "", SettingKeyOIDCConnectRedirectURL: "",
SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
SettingKeyOIDCConnectUsePKCE: "true", SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
SettingKeyOIDCConnectValidateIDToken: "true", SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
SettingKeyOIDCConnectClockSkewSeconds: "120", SettingKeyOIDCConnectClockSkewSeconds: "120",
SettingKeyOIDCConnectRequireEmailVerified: "false", SettingKeyOIDCConnectRequireEmailVerified: "false",
...@@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ...@@ -1767,12 +1792,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
result.OIDCConnectUsePKCE = raw == "true" result.OIDCConnectUsePKCE = raw == "true"
} else { } else {
result.OIDCConnectUsePKCE = oidcBase.UsePKCE result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
} }
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
result.OIDCConnectValidateIDToken = raw == "true" result.OIDCConnectValidateIDToken = raw == "true"
} else { } else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
} }
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
...@@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. ...@@ -2482,9 +2507,13 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
} }
if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
effective.UsePKCE = raw == "true" effective.UsePKCE = raw == "true"
} else {
effective.UsePKCE = oidcUsePKCECompatibilityDefault(effective)
} }
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true" effective.ValidateIDToken = raw == "true"
} else {
effective.ValidateIDToken = oidcValidateIDTokenCompatibilityDefault(effective)
} }
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v) effective.AllowedSigningAlgs = strings.TrimSpace(v)
......
...@@ -119,7 +119,9 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue ...@@ -119,7 +119,9 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{ svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{ OIDC: config.OIDCConnectConfig{
UsePKCE: true, UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true, ValidateIDToken: true,
ValidateIDTokenExplicit: true,
}, },
}) })
...@@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue ...@@ -131,6 +133,22 @@ func TestSettingService_ParseSettings_DefaultsOIDCSecurityFlagsToSafeConfigValue
require.True(t, got.OIDCConnectValidateIDToken) require.True(t, got.OIDCConnectValidateIDToken)
} }
func TestSettingService_ParseSettings_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
svc := NewSettingService(&settingOIDCRepoStub{values: map[string]string{}}, &config.Config{
OIDC: config.OIDCConnectConfig{
UsePKCE: true,
ValidateIDToken: true,
},
})
got := svc.parseSettings(map[string]string{
SettingKeyOIDCConnectEnabled: "true",
})
require.False(t, got.OIDCConnectUsePKCE)
require.False(t, got.OIDCConnectValidateIDToken)
}
func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) { func TestGetOIDCConnectOAuthConfig_AllowsCompatibilityFlagsToDisablePKCEAndIDTokenValidation(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{
OIDC: config.OIDCConnectConfig{ OIDC: config.OIDCConnectConfig{
...@@ -179,7 +197,9 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t ...@@ -179,7 +197,9 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
Scopes: "openid email profile", Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post", TokenAuthMethod: "client_secret_post",
UsePKCE: true, UsePKCE: true,
UsePKCEExplicit: true,
ValidateIDToken: true, ValidateIDToken: true,
ValidateIDTokenExplicit: true,
AllowedSigningAlgs: "RS256", AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120, ClockSkewSeconds: 120,
}, },
...@@ -195,3 +215,37 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t ...@@ -195,3 +215,37 @@ func TestGetOIDCConnectOAuthConfig_DefaultsToSecureFlagsWhenSettingsMissing(t *t
require.True(t, got.UsePKCE) require.True(t, got.UsePKCE)
require.True(t, got.ValidateIDToken) require.True(t, got.ValidateIDToken)
} }
func TestGetOIDCConnectOAuthConfig_UsesLegacyOIDCCompatibilityFlagsWhenSettingsMissing(t *testing.T) {
cfg := &config.Config{
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
}
repo := &settingOIDCRepoStub{values: map[string]string{
SettingKeyOIDCConnectEnabled: "true",
}}
svc := NewSettingService(repo, cfg)
got, err := svc.GetOIDCConnectOAuthConfig(context.Background())
require.NoError(t, err)
require.False(t, got.UsePKCE)
require.False(t, got.ValidateIDToken)
}
...@@ -38,23 +38,22 @@ VALUES ...@@ -38,23 +38,22 @@ VALUES
('auth_source_default_email_balance', '0'), ('auth_source_default_email_balance', '0'),
('auth_source_default_email_concurrency', '5'), ('auth_source_default_email_concurrency', '5'),
('auth_source_default_email_subscriptions', '[]'), ('auth_source_default_email_subscriptions', '[]'),
('auth_source_default_email_grant_on_signup', 'true'), ('auth_source_default_email_grant_on_signup', 'false'),
('auth_source_default_email_grant_on_first_bind', 'false'), ('auth_source_default_email_grant_on_first_bind', 'false'),
('auth_source_default_linuxdo_balance', '0'), ('auth_source_default_linuxdo_balance', '0'),
('auth_source_default_linuxdo_concurrency', '5'), ('auth_source_default_linuxdo_concurrency', '5'),
('auth_source_default_linuxdo_subscriptions', '[]'), ('auth_source_default_linuxdo_subscriptions', '[]'),
('auth_source_default_linuxdo_grant_on_signup', 'true'), ('auth_source_default_linuxdo_grant_on_signup', 'false'),
('auth_source_default_linuxdo_grant_on_first_bind', 'false'), ('auth_source_default_linuxdo_grant_on_first_bind', 'false'),
('auth_source_default_oidc_balance', '0'), ('auth_source_default_oidc_balance', '0'),
('auth_source_default_oidc_concurrency', '5'), ('auth_source_default_oidc_concurrency', '5'),
('auth_source_default_oidc_subscriptions', '[]'), ('auth_source_default_oidc_subscriptions', '[]'),
('auth_source_default_oidc_grant_on_signup', 'true'), ('auth_source_default_oidc_grant_on_signup', 'false'),
('auth_source_default_oidc_grant_on_first_bind', 'false'), ('auth_source_default_oidc_grant_on_first_bind', 'false'),
('auth_source_default_wechat_balance', '0'), ('auth_source_default_wechat_balance', '0'),
('auth_source_default_wechat_concurrency', '5'), ('auth_source_default_wechat_concurrency', '5'),
('auth_source_default_wechat_subscriptions', '[]'), ('auth_source_default_wechat_subscriptions', '[]'),
('auth_source_default_wechat_grant_on_signup', 'true'), ('auth_source_default_wechat_grant_on_signup', 'false'),
('auth_source_default_wechat_grant_on_first_bind', 'false'), ('auth_source_default_wechat_grant_on_first_bind', 'false'),
('force_email_on_third_party_signup', 'false') ('force_email_on_third_party_signup', 'false')
ON CONFLICT (key) DO NOTHING; ON CONFLICT (key) DO NOTHING;
...@@ -31,6 +31,41 @@ BEGIN ...@@ -31,6 +31,41 @@ BEGIN
END IF; END IF;
EXECUTE $sql$ EXECUTE $sql$
WITH legacy AS (
SELECT
uei.id,
uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo'
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
),
legacy_subjects AS (
SELECT
provider_user_id AS provider_subject,
COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_user_id
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_user_id
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN legacy_subjects AS subjects
ON subjects.provider_subject = legacy.provider_user_id
AND subjects.distinct_user_count = 1
)
INSERT INTO auth_identities ( INSERT INTO auth_identities (
user_id, user_id,
provider_type, provider_type,
...@@ -52,11 +87,18 @@ SELECT ...@@ -52,11 +87,18 @@ SELECT
'display_name', legacy.display_name, 'display_name', legacy.display_name,
'migration', '115_auth_identity_legacy_external_backfill' 'migration', '115_auth_identity_legacy_external_backfill'
) )
FROM ( FROM canonical_legacy AS legacy
WHERE legacy.canonical_row_num = 1
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
EXECUTE $sql$
WITH legacy AS (
SELECT SELECT
uei.id, uei.id,
uei.user_id, uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id, BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username, BTRIM(uei.provider_username) AS provider_username,
BTRIM(uei.display_name) AS display_name, BTRIM(uei.display_name) AS display_name,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
...@@ -65,13 +107,28 @@ FROM ( ...@@ -65,13 +107,28 @@ FROM (
FROM user_external_identities AS uei FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
) AS legacy ),
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; legacy_subjects AS (
$sql$; SELECT
provider_union_id AS provider_subject,
EXECUTE $sql$ COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_union_id
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_union_id
ORDER BY COALESCE(legacy.updated_at, legacy.created_at, NOW()) DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN legacy_subjects AS subjects
ON subjects.provider_subject = legacy.provider_union_id
AND subjects.distinct_user_count = 1
)
INSERT INTO auth_identities ( INSERT INTO auth_identities (
user_id, user_id,
provider_type, provider_type,
...@@ -96,27 +153,36 @@ SELECT ...@@ -96,27 +153,36 @@ SELECT
'display_name', legacy.display_name, 'display_name', legacy.display_name,
'migration', '115_auth_identity_legacy_external_backfill' 'migration', '115_auth_identity_legacy_external_backfill'
) )
FROM ( FROM canonical_legacy AS legacy
WHERE legacy.canonical_row_num = 1
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$;
EXECUTE $sql$
WITH legacy AS (
SELECT SELECT
uei.id,
uei.user_id, uei.user_id,
BTRIM(uei.provider_user_id) AS provider_user_id, BTRIM(uei.provider_user_id) AS provider_user_id,
BTRIM(uei.provider_union_id) AS provider_union_id, BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(uei.provider_username) AS provider_username, BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
BTRIM(uei.display_name) AS display_name, BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json, meta.metadata_json
uei.created_at,
uei.updated_at
FROM user_external_identities AS uei FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
) AS legacy ),
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; legacy_subjects AS (
$sql$; SELECT
provider_union_id AS provider_subject,
EXECUTE $sql$ COUNT(DISTINCT user_id) AS distinct_user_count
FROM legacy
GROUP BY provider_union_id
)
INSERT INTO auth_identity_channels ( INSERT INTO auth_identity_channels (
identity_id, identity_id,
provider_type, provider_type,
...@@ -138,23 +204,10 @@ SELECT ...@@ -138,23 +204,10 @@ SELECT
'unionid', legacy.provider_union_id, 'unionid', legacy.provider_union_id,
'migration', '115_auth_identity_legacy_external_backfill' 'migration', '115_auth_identity_legacy_external_backfill'
) )
FROM ( FROM legacy
SELECT JOIN legacy_subjects AS subjects
uei.user_id, ON subjects.provider_subject = legacy.provider_union_id
BTRIM(uei.provider_user_id) AS provider_user_id, AND subjects.distinct_user_count = 1
BTRIM(uei.provider_union_id) AS provider_union_id,
BTRIM(COALESCE(meta.metadata_json ->> 'channel', '')) AS channel,
BTRIM(COALESCE(meta.metadata_json ->> 'channel_app_id', meta.metadata_json ->> 'appid', meta.metadata_json ->> 'app_id', '')) AS channel_app_id,
meta.metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
CROSS JOIN LATERAL (
SELECT public.__migration_115_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
) AS meta
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
) AS legacy
JOIN auth_identities AS ai JOIN auth_identities AS ai
ON ai.user_id = legacy.user_id ON ai.user_id = legacy.user_id
AND ai.provider_type = 'wechat' AND ai.provider_type = 'wechat'
......
...@@ -74,6 +74,82 @@ $sql$; ...@@ -74,6 +74,82 @@ $sql$;
EXECUTE $sql$ EXECUTE $sql$
INSERT INTO auth_identity_migration_reports (report_type, report_key, details) INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'legacy_external_identity_conflict',
'legacy_external_identity:' || legacy.id::text,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'legacy_user_id', legacy.user_id,
'provider_type', legacy.provider_type,
'provider_key', legacy.provider_key,
'provider_subject', legacy.provider_subject,
'conflicting_legacy_user_ids', ambiguous.conflicting_legacy_user_ids,
'reason', 'legacy canonical identity subject belongs to multiple legacy users and cannot be auto-resolved',
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM (
SELECT
uei.id,
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy
JOIN (
SELECT
provider_type,
provider_key,
provider_subject,
to_jsonb(array_agg(DISTINCT user_id ORDER BY user_id)) AS conflicting_legacy_user_ids
FROM (
SELECT
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy_subjects
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) > 1
) AS ambiguous
ON ambiguous.provider_type = legacy.provider_type
AND ambiguous.provider_key = legacy.provider_key
AND ambiguous.provider_subject = legacy.provider_subject
ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$;
EXECUTE $sql$
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT SELECT
'legacy_external_identity_conflict', 'legacy_external_identity_conflict',
'legacy_external_identity:' || legacy.id::text, 'legacy_external_identity:' || legacy.id::text,
...@@ -116,6 +192,39 @@ FROM ( ...@@ -116,6 +192,39 @@ FROM (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
) )
) AS legacy ) AS legacy
JOIN (
SELECT
provider_type,
provider_key,
provider_subject
FROM (
SELECT
uei.user_id,
LOWER(BTRIM(COALESCE(uei.provider, ''))) AS provider_type,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN 'wechat-main'
ELSE 'linuxdo'
END AS provider_key,
CASE
WHEN LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' THEN BTRIM(COALESCE(uei.provider_union_id, ''))
ELSE BTRIM(COALESCE(uei.provider_user_id, ''))
END AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) IN ('linuxdo', 'wechat')
AND (
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'linuxdo' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '')
OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
)
) AS legacy_subjects
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) = 1
) AS clear_subjects
ON clear_subjects.provider_type = legacy.provider_type
AND clear_subjects.provider_key = legacy.provider_key
AND clear_subjects.provider_subject = legacy.provider_subject
JOIN auth_identities AS ai JOIN auth_identities AS ai
ON ai.provider_type = legacy.provider_type ON ai.provider_type = legacy.provider_type
AND ai.provider_key = legacy.provider_key AND ai.provider_key = legacy.provider_key
...@@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING; ...@@ -125,29 +234,7 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$; $sql$;
EXECUTE $sql$ EXECUTE $sql$
INSERT INTO auth_identities ( WITH legacy AS (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
legacy.user_id,
legacy.provider_type,
legacy.provider_key,
legacy.provider_subject,
legacy.verified_at,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'provider_user_id', legacy.provider_user_id,
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
'provider_username', legacy.provider_username,
'display_name', legacy.display_name,
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM (
SELECT SELECT
uei.id, uei.id,
uei.user_id, uei.user_id,
...@@ -175,12 +262,58 @@ FROM ( ...@@ -175,12 +262,58 @@ FROM (
OR OR
(LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '') (LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '')
) )
) AS legacy ),
clear_subjects AS (
SELECT
provider_type,
provider_key,
provider_subject
FROM legacy
GROUP BY provider_type, provider_key, provider_subject
HAVING COUNT(DISTINCT user_id) = 1
),
canonical_legacy AS (
SELECT
legacy.*,
ROW_NUMBER() OVER (
PARTITION BY legacy.provider_type, legacy.provider_key, legacy.provider_subject
ORDER BY legacy.verified_at DESC, legacy.id DESC
) AS canonical_row_num
FROM legacy
JOIN clear_subjects
ON clear_subjects.provider_type = legacy.provider_type
AND clear_subjects.provider_key = legacy.provider_key
AND clear_subjects.provider_subject = legacy.provider_subject
)
INSERT INTO auth_identities (
user_id,
provider_type,
provider_key,
provider_subject,
verified_at,
metadata
)
SELECT
legacy.user_id,
legacy.provider_type,
legacy.provider_key,
legacy.provider_subject,
legacy.verified_at,
legacy.metadata_json || jsonb_build_object(
'legacy_identity_id', legacy.id,
'provider_user_id', legacy.provider_user_id,
'provider_union_id', NULLIF(legacy.provider_union_id, ''),
'provider_username', legacy.provider_username,
'display_name', legacy.display_name,
'migration', '116_auth_identity_legacy_external_safety_reports'
)
FROM canonical_legacy AS legacy
LEFT JOIN auth_identities AS ai LEFT JOIN auth_identities AS ai
ON ai.provider_type = legacy.provider_type ON ai.provider_type = legacy.provider_type
AND ai.provider_key = legacy.provider_key AND ai.provider_key = legacy.provider_key
AND ai.provider_subject = legacy.provider_subject AND ai.provider_subject = legacy.provider_subject
WHERE ai.id IS NULL WHERE legacy.canonical_row_num = 1
AND ai.id IS NULL
ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING;
$sql$; $sql$;
...@@ -225,6 +358,19 @@ FROM ( ...@@ -225,6 +358,19 @@ FROM (
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> '' AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
) AS legacy ) AS legacy
JOIN (
SELECT
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_subject
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
GROUP BY BTRIM(COALESCE(uei.provider_union_id, ''))
HAVING COUNT(DISTINCT uei.user_id) = 1
) AS clear_subjects
ON clear_subjects.provider_subject = legacy.provider_union_id
JOIN auth_identities AS legacy_ai JOIN auth_identities AS legacy_ai
ON legacy_ai.user_id = legacy.user_id ON legacy_ai.user_id = legacy.user_id
AND legacy_ai.provider_type = 'wechat' AND legacy_ai.provider_type = 'wechat'
...@@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING; ...@@ -245,6 +391,33 @@ ON CONFLICT (report_type, report_key) DO NOTHING;
$sql$; $sql$;
EXECUTE $sql$ EXECUTE $sql$
WITH legacy AS (
SELECT
uei.user_id,
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
BTRIM(COALESCE(
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
''
)) AS channel_app_id
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
),
clear_subjects AS (
SELECT
provider_union_id AS provider_subject
FROM legacy
GROUP BY provider_union_id
HAVING COUNT(DISTINCT user_id) = 1
)
INSERT INTO auth_identity_channels ( INSERT INTO auth_identity_channels (
identity_id, identity_id,
provider_type, provider_type,
...@@ -266,26 +439,9 @@ SELECT ...@@ -266,26 +439,9 @@ SELECT
'unionid', legacy.provider_union_id, 'unionid', legacy.provider_union_id,
'migration', '116_auth_identity_legacy_external_safety_reports' 'migration', '116_auth_identity_legacy_external_safety_reports'
) )
FROM ( FROM legacy
SELECT JOIN clear_subjects
uei.user_id, ON clear_subjects.provider_subject = legacy.provider_union_id
BTRIM(COALESCE(uei.provider_user_id, '')) AS provider_user_id,
BTRIM(COALESCE(uei.provider_union_id, '')) AS provider_union_id,
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) AS metadata_json,
BTRIM(COALESCE(public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel', '')) AS channel,
BTRIM(COALESCE(
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'channel_app_id',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'appid',
public.__migration_116_safe_legacy_metadata_jsonb(uei.metadata) ->> 'app_id',
''
)) AS channel_app_id
FROM user_external_identities AS uei
JOIN users AS u ON u.id = uei.user_id
WHERE u.deleted_at IS NULL
AND LOWER(BTRIM(COALESCE(uei.provider, ''))) = 'wechat'
AND BTRIM(COALESCE(uei.provider_union_id, '')) <> ''
AND BTRIM(COALESCE(uei.provider_user_id, '')) <> ''
) AS legacy
JOIN auth_identities AS legacy_ai JOIN auth_identities AS legacy_ai
ON legacy_ai.user_id = legacy.user_id ON legacy_ai.user_id = legacy.user_id
AND legacy_ai.provider_type = 'wechat' AND legacy_ai.provider_type = 'wechat'
......
-- Intentionally left as a no-op. -- Auto-backfill untouched migration 110 signup-grant defaults to the corrected false value.
-- Legacy installs may have intentionally kept the original signup grant defaults, -- Rows still matching the migration-110 default payload and timestamp window are treated as
-- and we cannot distinguish those cases safely from untouched migration 110 rows. -- untouched legacy defaults; any remaining legacy true values are reported for manual review.
WITH migration_110 AS (
SELECT applied_at
FROM schema_migrations
WHERE filename = '110_pending_auth_and_provider_default_grants.sql'
),
providers AS (
SELECT provider_type
FROM (
VALUES ('email'), ('linuxdo'), ('oidc'), ('wechat')
) AS providers(provider_type)
),
legacy_provider_defaults AS (
SELECT providers.provider_type
FROM providers
CROSS JOIN migration_110
JOIN settings balance
ON balance.key = 'auth_source_default_' || providers.provider_type || '_balance'
JOIN settings concurrency
ON concurrency.key = 'auth_source_default_' || providers.provider_type || '_concurrency'
JOIN settings subscriptions
ON subscriptions.key = 'auth_source_default_' || providers.provider_type || '_subscriptions'
JOIN settings grant_on_signup
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
JOIN settings grant_on_first_bind
ON grant_on_first_bind.key = 'auth_source_default_' || providers.provider_type || '_grant_on_first_bind'
WHERE balance.value = '0'
AND concurrency.value = '5'
AND subscriptions.value = '[]'
AND grant_on_signup.value = 'true'
AND grant_on_first_bind.value = 'false'
AND balance.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND concurrency.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND subscriptions.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_signup.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
AND grant_on_first_bind.updated_at BETWEEN migration_110.applied_at - INTERVAL '1 minute' AND migration_110.applied_at + INTERVAL '1 minute'
),
updated_signup_grants AS (
UPDATE settings
SET
value = 'false',
updated_at = NOW()
FROM legacy_provider_defaults
WHERE settings.key = 'auth_source_default_' || legacy_provider_defaults.provider_type || '_grant_on_signup'
AND settings.value = 'true'
RETURNING legacy_provider_defaults.provider_type
)
INSERT INTO auth_identity_migration_reports (report_type, report_key, details)
SELECT
'legacy_auth_source_signup_grant_review',
providers.provider_type,
jsonb_build_object(
'provider_type', providers.provider_type,
'current_value', grant_on_signup.value,
'auto_backfilled', FALSE,
'reason', 'legacy_true_default_not_auto_backfilled'
)
FROM providers
JOIN settings grant_on_signup
ON grant_on_signup.key = 'auth_source_default_' || providers.provider_type || '_grant_on_signup'
LEFT JOIN updated_signup_grants
ON updated_signup_grants.provider_type = providers.provider_type
WHERE grant_on_signup.value = 'true'
AND updated_signup_grants.provider_type IS NULL
ON CONFLICT (report_type, report_key) DO NOTHING;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment