Commit 36aed359 authored by IanShaw027's avatar IanShaw027
Browse files

fix(auth): harden oauth identity upgrade paths

parent 3d29f7c2
...@@ -3,7 +3,9 @@ package schema ...@@ -3,7 +3,9 @@ package schema
import ( import (
"testing" "testing"
"entgo.io/ent"
"entgo.io/ent/entc/load" "entgo.io/ent/entc/load"
"entgo.io/ent/schema/field"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
...@@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) { ...@@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
userSchema := requireSchema(t, schemas, "User") userSchema := requireSchema(t, schemas, "User")
requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at")
signupSource := requireSchemaField(t, userSchema, "signup_source")
require.Equal(t, field.TypeString, signupSource.Info.Type)
require.True(t, signupSource.Default)
require.Equal(t, "email", signupSource.DefaultValue)
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("github"))
} }
func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema {
...@@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { ...@@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
} }
} }
func requireSchemaField(t *testing.T, schema *load.Schema, name string) *load.Field {
t.Helper()
for _, schemaField := range schema.Fields {
if schemaField.Name == name {
return schemaField
}
}
require.Failf(t, "missing schema field", "schema %s should include field %s", schema.Name, name)
return nil
}
func requireStringFieldValidator(t *testing.T, fields []ent.Field, name string) func(string) error {
t.Helper()
for _, entField := range fields {
descriptor := entField.Descriptor()
if descriptor.Name != name {
continue
}
require.NotEmpty(t, descriptor.Validators, "field %s should include a validator", name)
validator, ok := descriptor.Validators[0].(func(string) error)
require.True(t, ok, "field %s validator should be func(string) error", name)
return validator
}
require.Failf(t, "missing field validator", "schema should include field %s", name)
return nil
}
func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) {
t.Helper() t.Helper()
......
package schema package schema
import ( import (
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins" "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/domain"
...@@ -73,7 +75,14 @@ func (User) Fields() []ent.Field { ...@@ -73,7 +75,14 @@ func (User) Fields() []ent.Field {
Optional(). Optional().
Nillable(), Nillable(),
field.String("signup_source"). field.String("signup_source").
MaxLen(20). Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc")
}
}).
Default("email"), Default("email"),
field.Time("last_login_at"). field.Time("last_login_at").
Optional(). Optional().
......
...@@ -211,25 +211,27 @@ type WeChatConnectConfig struct { ...@@ -211,25 +211,27 @@ type WeChatConnectConfig struct {
} }
type OIDCConnectConfig struct { type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"` Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等
ClientID string `mapstructure:"client_id"` ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"` ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"` IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"` DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"` AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"` TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"` UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"` JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" Scopes string `mapstructure:"scopes"` // 默认 "openid email profile"
RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记)
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none
UsePKCE bool `mapstructure:"use_pkce"` UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"` ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" UsePKCEExplicit bool `mapstructure:"-" yaml:"-"`
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 ValidateIDTokenExplicit bool `mapstructure:"-" yaml:"-"`
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256"
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120
RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。 // 为空时,服务端会尝试一组常见字段名。
...@@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool { ...@@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
return !hasNewEnv return !hasNewEnv
} }
func hasExplicitConfigOrEnv(configKey, envKey string) bool {
if viper.InConfig(configKey) {
return true
}
_, ok := os.LookupEnv(envKey)
return ok
}
func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) { func applyLegacyWeChatConnectEnvCompatibility(cfg *WeChatConnectConfig) {
if cfg == nil { if cfg == nil {
return return
...@@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) { ...@@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath) cfg.OIDC.UserInfoEmailPath = strings.TrimSpace(cfg.OIDC.UserInfoEmailPath)
cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath) cfg.OIDC.UserInfoIDPath = strings.TrimSpace(cfg.OIDC.UserInfoIDPath)
cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath) cfg.OIDC.UserInfoUsernamePath = strings.TrimSpace(cfg.OIDC.UserInfoUsernamePath)
cfg.OIDC.UsePKCEExplicit = hasExplicitConfigOrEnv("oidc_connect.use_pkce", "OIDC_CONNECT_USE_PKCE")
cfg.OIDC.ValidateIDTokenExplicit = hasExplicitConfigOrEnv("oidc_connect.validate_id_token", "OIDC_CONNECT_VALIDATE_ID_TOKEN")
cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins)
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
......
...@@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) { ...@@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, cfg.OIDC.UsePKCE) require.True(t, cfg.OIDC.UsePKCE)
require.True(t, cfg.OIDC.ValidateIDToken) require.True(t, cfg.OIDC.ValidateIDToken)
require.False(t, cfg.OIDC.UsePKCEExplicit)
require.False(t, cfg.OIDC.ValidateIDTokenExplicit)
}
func TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit(t *testing.T) {
resetViperWithJWTSecret(t)
t.Setenv("OIDC_CONNECT_USE_PKCE", "false")
t.Setenv("OIDC_CONNECT_VALIDATE_ID_TOKEN", "false")
cfg, err := Load()
require.NoError(t, err)
require.False(t, cfg.OIDC.UsePKCE)
require.False(t, cfg.OIDC.ValidateIDToken)
require.True(t, cfg.OIDC.UsePKCEExplicit)
require.True(t, cfg.OIDC.ValidateIDTokenExplicit)
} }
func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
......
...@@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla ...@@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
require.Equal(t, false, data["oidc_connect_validate_id_token"]) require.Equal(t, false, data["oidc_connect_validate_id_token"])
} }
func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
},
}
svc := service.NewSettingService(repo, &config.Config{
Default: config.DefaultConfig{UserConcurrency: 5},
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
Scopes: "openid email profile",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
}
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{ repo := &settingHandlerRepoStub{
......
...@@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri ...@@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
} }
userEntity, err := client.User.Query(). userEntity, err := client.User.Query().
Where(dbuser.EmailEqualFold(email)). Where(userNormalizedEmailPredicate(email)).
Only(ctx) Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil { if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
} }
return userEntity, nil switch len(userEntity) {
case 0:
return nil, nil
case 1:
return userEntity[0], nil
default:
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
}
} }
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
...@@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession( ...@@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
completionResponse["choice_reason"] = "force_email_on_signup" completionResponse["choice_reason"] = "force_email_on_signup"
} }
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin, Intent: oauthIntentLogin,
Identity: identity, Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail, ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo, RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey, BrowserSessionKey: browserSessionKey,
...@@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return return
} }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) client := h.entClient()
if err != nil { if client == nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return return
} }
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
...@@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { ...@@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) if err != nil {
response.ErrorFrom(c, err)
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { respondPendingOAuthBindingApplyError(c, err)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
...@@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test ...@@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
ctx := context.Background() ctx := context.Background()
existingUser, err := client.User.Create(). existingUser, err := client.User.Create().
SetEmail("legacy@example.com"). SetEmail(" Legacy@Example.com ").
SetUsername("legacy-user"). SetUsername("legacy-user").
SetPasswordHash("hash"). SetPasswordHash("hash").
SetRole(service.RoleUser). SetRole(service.RoleUser).
...@@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test ...@@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
Only(ctx) Only(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent) require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID) require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"]) require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, oauthPendingChoiceStep, completion["step"]) require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, existingUser.Email, completion["email"]) require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
require.Equal(t, existingUser.Email, completion["existing_account_email"]) require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
require.Equal(t, true, completion["existing_account_bindable"]) require.Equal(t, true, completion["existing_account_bindable"])
require.Equal(t, "compat_email_match", completion["choice_reason"]) require.Equal(t, "compat_email_match", completion["choice_reason"])
_, hasAccessToken := completion["access_token"] _, hasAccessToken := completion["access_token"]
...@@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te ...@@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te
require.False(t, decision.AdoptAvatar) require.False(t, decision.AdoptAvatar)
} }
func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-conflict-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler { func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper() t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg) handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
......
...@@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) { ...@@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email := strings.TrimSpace(strings.ToLower(req.Email)) email := strings.TrimSpace(strings.ToLower(req.Email))
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil { if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email ...@@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return matches[0], nil return matches[0], nil
} }
func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
if client == nil || session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity == nil || identity.UserID <= 0 {
return nil
}
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
if err != nil {
return err
}
if activeOwner != nil {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil { if session == nil {
return nil return nil
...@@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx( ...@@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx(
return nil return nil
} }
func applyPendingOAuthAdoptionAndConsumeSession(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
userID int64,
) error {
if client == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if session == nil || userID <= 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
return err
}
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthAdoption( func applyPendingOAuthAdoption(
ctx context.Context, ctx context.Context,
client *dbent.Client, client *dbent.Client,
...@@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState( ...@@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c *gin.Context, c *gin.Context,
client *dbent.Client, client *dbent.Client,
session *dbent.PendingAuthSession, session *dbent.PendingAuthSession,
targetUser *dbent.User,
email string, email string,
) (*dbent.PendingAuthSession, error) { ) (*dbent.PendingAuthSession, error) {
completionResponse := pendingOAuthChoiceCompletionResponse(session, email) completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
var targetUserID *int64
if targetUser != nil && targetUser.ID > 0 {
targetUserID = &targetUser.ID
}
session, err := updatePendingOAuthSessionProgress( session, err := updatePendingOAuthSessionProgress(
c.Request.Context(), c.Request.Context(),
client, client,
session, session,
strings.TrimSpace(session.Intent), strings.TrimSpace(session.Intent),
email, email,
nil, targetUserID,
completionResponse, completionResponse,
) )
if err != nil { if err != nil {
...@@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
} }
} }
if existingUser != nil { if existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
...@@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) ...@@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
) )
if err != nil { if err != nil {
if errors.Is(err, service.ErrEmailExists) { if errors.Is(err, service.ErrEmailExists) {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, email) existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
if lookupErr != nil {
response.ErrorFrom(c, lookupErr)
return
}
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
......
...@@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * ...@@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background() ctx := context.Background()
_, err := client.User.Create(). existingUser, err := client.User.Create().
SetEmail("owner@example.com"). SetEmail("owner@example.com").
SetUsername("owner-user"). SetUsername("owner-user").
SetPasswordHash("hash"). SetPasswordHash("hash").
...@@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t * ...@@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent) require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID) require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
require.Nil(t, storedSession.ConsumedAt) require.Nil(t, storedSession.ConsumedAt)
...@@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te ...@@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background() ctx := context.Background()
_, err := client.User.Create(). existingUser, err := client.User.Create().
SetEmail(" Owner@Example.com "). SetEmail(" Owner@Example.com ").
SetUsername("owner-user"). SetUsername("owner-user").
SetPasswordHash("hash"). SetPasswordHash("hash").
...@@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te ...@@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, storedSession.TargetUserID) require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
} }
...@@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing ...@@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790")
ctx := context.Background() ctx := context.Background()
_, err := client.User.Create(). existingUser, err := client.User.Create().
SetEmail("owner@example.com"). SetEmail("owner@example.com").
SetUsername("owner-user"). SetUsername("owner-user").
SetPasswordHash("hash"). SetPasswordHash("hash").
...@@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing ...@@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, oauthIntentLogin, storedSession.Intent) require.Equal(t, oauthIntentLogin, storedSession.Intent)
require.Nil(t, storedSession.TargetUserID) require.NotNil(t, storedSession.TargetUserID)
require.Equal(t, existingUser.ID, *storedSession.TargetUserID)
require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) require.Equal(t, "owner@example.com", storedSession.ResolvedEmail)
} }
......
...@@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession( ...@@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if compatEmailUser != nil { if compatEmailUser != nil {
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
} }
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin, Intent: oauthIntentLogin,
Identity: identity, Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail, ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo, RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey, BrowserSessionKey: browserSessionKey,
...@@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return return
} }
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) client := h.entClient()
if err != nil { if client == nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return return
} }
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
...@@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { ...@@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) if err != nil {
response.ErrorFrom(c, err)
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { respondPendingOAuthBindingApplyError(c, err)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return return
} }
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie)
......
...@@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing ...@@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only(ctx) Only(ctx)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent) require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID) require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, existingUser.Email, session.ResolvedEmail) require.Equal(t, existingUser.Email, session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
...@@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi ...@@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi
require.False(t, decision.AdoptAvatar) require.False(t, decision.AdoptAvatar)
} }
func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("oidc-complete-conflict-session").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example.com").
SetProviderSubject("oidc-conflict-subject").
SetResolvedEmail("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid").
SetBrowserSessionKey("oidc-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "oidc_user",
"issuer": "https://issuer.example.com",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-conflict-browser")})
c.Request = req
handler.CompleteOIDCOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
type oidcProviderFixture struct { type oidcProviderFixture struct {
Subject string Subject string
PreferredUsername string PreferredUsername string
......
...@@ -576,6 +576,258 @@ FROM auth_identity_migration_reports ...@@ -576,6 +576,258 @@ FROM auth_identity_migration_reports
require.Equal(t, beforeCount, afterCount) require.Equal(t, beforeCount, afterCount)
} }
func TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`, wechatFirstUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`, wechatSecondUserID).Scan(new(int64)))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxDoIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`).Scan(&linuxDoIdentityCount))
require.Zero(t, linuxDoIdentityCount)
var wechatIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`).Scan(&wechatIdentityCount))
require.Zero(t, wechatIdentityCount)
var wechatChannelCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`).Scan(&wechatChannelCount))
require.Zero(t, wechatChannelCount)
}
func TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migration115SQL, err := os.ReadFile(migration115Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
prepareLegacyExternalIdentitiesTable(t, tx, ctx)
truncateAuthIdentityLegacyFixtureTables(t, tx, ctx)
var linuxDoFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoFirstUserID))
var linuxDoSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoSecondUserID))
var wechatFirstUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatFirstUserID))
var wechatSecondUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatSecondUserID))
var linuxDoFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`, linuxDoFirstUserID).Scan(&linuxDoFirstLegacyID))
var linuxDoSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`, linuxDoSecondUserID).Scan(&linuxDoSecondLegacyID))
var wechatFirstLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`, wechatFirstUserID).Scan(&wechatFirstLegacyID))
var wechatSecondLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`, wechatSecondUserID).Scan(&wechatSecondLegacyID))
_, err = tx.ExecContext(ctx, string(migration115SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var identityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`).Scan(&identityCount))
require.Zero(t, identityCount)
var conflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&conflictReportCount))
require.Equal(t, 4, conflictReportCount)
var winnerAttributedReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(linuxDoSecondLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatFirstLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatSecondLegacyID, 10)).Scan(&winnerAttributedReportCount))
require.Zero(t, winnerAttributedReportCount)
}
func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) { func TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121(t *testing.T) {
tx := testTx(t) tx := testTx(t)
ctx := context.Background() ctx := context.Background()
......
...@@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( ...@@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond const migrationsLockRetryInterval = 500 * time.Millisecond
const nonTransactionalMigrationSuffix = "_notx.sql" const nonTransactionalMigrationSuffix = "_notx.sql"
const paymentOrdersOutTradeNoUniqueMigration = "120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const paymentOrdersOutTradeNoUniqueIndex = "paymentorder_out_trade_no_unique"
type migrationChecksumCompatibilityRule struct { type migrationChecksumCompatibilityRule struct {
fileChecksum string fileChecksum string
...@@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil ...@@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"), "054_drop_legacy_cache_columns.sql": newMigrationChecksumCompatibilityRule("82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"),
"061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"), "061_add_usage_log_request_type.sql": newMigrationChecksumCompatibilityRule("66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c", "08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0", "222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"),
"109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"), "109_auth_identity_compat_backfill.sql": newMigrationChecksumCompatibilityRule("2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3", "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"),
"115_auth_identity_legacy_external_backfill.sql": newMigrationChecksumCompatibilityRule("022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f", "4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"),
"116_auth_identity_legacy_external_safety_reports.sql": newMigrationChecksumCompatibilityRule("07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488", "f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"),
"118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"), "118_wechat_dual_mode_and_auth_source_defaults.sql": newMigrationChecksumCompatibilityRule("b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0", "e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"),
"119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"), "119_enforce_payment_orders_out_trade_no_unique.sql": newMigrationChecksumCompatibilityRule("0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e", "ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"), "120_enforce_payment_orders_out_trade_no_unique_notx.sql": newMigrationChecksumCompatibilityRule("34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074", "e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61", "707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22", "04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"),
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"), "123_fix_legacy_auth_source_grant_on_signup_defaults.sql": newMigrationChecksumCompatibilityRule("2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57", "6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"),
} }
...@@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
} }
if nonTx { if nonTx {
if err := prepareNonTransactionalMigration(ctx, db, name); err != nil {
return fmt.Errorf("prepare migration %s: %w", name, err)
}
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content) statements := splitSQLStatements(content)
...@@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { ...@@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil return nil
} }
func prepareNonTransactionalMigration(ctx context.Context, db *sql.DB, name string) error {
switch name {
case paymentOrdersOutTradeNoUniqueMigration:
return preparePaymentOrdersOutTradeNoUniqueMigration(ctx, db)
default:
return nil
}
}
func preparePaymentOrdersOutTradeNoUniqueMigration(ctx context.Context, db *sql.DB) error {
duplicates, err := findDuplicatePaymentOrderOutTradeNos(ctx, db)
if err != nil {
return fmt.Errorf("precheck duplicate out_trade_no: %w", err)
}
if len(duplicates) > 0 {
return fmt.Errorf(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s",
paymentOrdersOutTradeNoUniqueMigration,
strings.Join(duplicates, ", "),
)
}
invalid, err := indexIsInvalid(ctx, db, paymentOrdersOutTradeNoUniqueIndex)
if err != nil {
return fmt.Errorf("check invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
if !invalid {
return nil
}
if _, err := db.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", paymentOrdersOutTradeNoUniqueIndex)); err != nil {
return fmt.Errorf("drop invalid index %s: %w", paymentOrdersOutTradeNoUniqueIndex, err)
}
return nil
}
func findDuplicatePaymentOrderOutTradeNos(ctx context.Context, db *sql.DB) ([]string, error) {
rows, err := db.QueryContext(ctx, `
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`)
if err != nil {
return nil, err
}
defer rows.Close()
duplicates := make([]string, 0, 5)
for rows.Next() {
var outTradeNo string
var duplicateCount int
if err := rows.Scan(&outTradeNo, &duplicateCount); err != nil {
return nil, err
}
duplicates = append(duplicates, fmt.Sprintf("%s (count=%d)", outTradeNo, duplicateCount))
}
if err := rows.Err(); err != nil {
return nil, err
}
return duplicates, nil
}
func indexIsInvalid(ctx context.Context, db *sql.DB, indexName string) (bool, error) {
var invalid bool
err := db.QueryRowContext(ctx, `
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`, indexName).Scan(&invalid)
return invalid, err
}
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error { func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations") hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil { if err != nil {
......
...@@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { ...@@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
require.True(t, ok) require.True(t, ok)
}) })
t.Run("115历史checksum可兼容修复后的legacy external backfill", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"115_auth_identity_legacy_external_backfill.sql",
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f",
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f",
)
require.True(t, ok)
})
t.Run("116历史checksum可兼容修复后的legacy external safety reports", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"116_auth_identity_legacy_external_safety_reports.sql",
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877",
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488",
)
require.True(t, ok)
})
t.Run("119历史checksum可兼容占位文件", func(t *testing.T) { t.Run("119历史checksum可兼容占位文件", func(t *testing.T) {
ok := isMigrationChecksumCompatible( ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql", "119_enforce_payment_orders_out_trade_no_unique.sql",
...@@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) { ...@@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
require.True(t, ok) require.True(t, ok)
}) })
t.Run("120多个历史checksum都可兼容新的notx修复版本", func(t *testing.T) {
for _, dbChecksum := range []string{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61",
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22",
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a",
} {
ok := isMigrationChecksumCompatible(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
dbChecksum,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074",
)
require.True(t, ok)
}
})
t.Run("119未知checksum不兼容", func(t *testing.T) { t.Run("119未知checksum不兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible( ok := isMigrationChecksumCompatible(
"119_enforce_payment_orders_out_trade_no_unique.sql", "119_enforce_payment_orders_out_trade_no_unique.sql",
......
...@@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { ...@@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) { func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
for _, name := range []string{ for _, name := range []string{
"115_auth_identity_legacy_external_backfill.sql",
"116_auth_identity_legacy_external_safety_reports.sql",
"118_wechat_dual_mode_and_auth_source_defaults.sql", "118_wechat_dual_mode_and_auth_source_defaults.sql",
"120_enforce_payment_orders_out_trade_no_unique_notx.sql", "120_enforce_payment_orders_out_trade_no_unique_notx.sql",
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql", "123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
......
...@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); ...@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet()) require.NoError(t, mock.ExpectationsWereMet())
} }
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "duplicate out_trade_no")
require.Contains(t, err.Error(), "dup-out-trade-no")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("paymentorder_out_trade_no_unique").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New() db, mock, err := sqlmock.New()
require.NoError(t, err) require.NoError(t, err)
......
...@@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) ...@@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T)
tx := testTx(t) tx := testTx(t)
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false) requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
requireConstraintDefinitionContains(
t,
tx,
"users",
"users_signup_source_check",
"signup_source",
"'email'",
"'linuxdo'",
"'wechat'",
"'oidc'",
)
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE") requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE") requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
...@@ -195,6 +208,45 @@ LIMIT 1 ...@@ -195,6 +208,45 @@ LIMIT 1
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable) require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
} }
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
t.Helper()
var def string
err := tx.QueryRowContext(context.Background(), `
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`, table, constraint).Scan(&def)
require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
}
}
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
t.Helper()
var columnDefault sql.NullString
err := tx.QueryRowContext(context.Background(), `
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`, table, column).Scan(&columnDefault)
require.NoError(t, err, "query column_default for %s.%s", table, column)
require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
for _, fragment := range fragments {
require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
}
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) { func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper() t.Helper()
......
...@@ -4,11 +4,15 @@ import ( ...@@ -4,11 +4,15 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"hash/fnv"
"reflect" "reflect"
"sort"
"strings" "strings"
"sync"
"time" "time"
"unsafe" "unsafe"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql" entsql "entgo.io/ent/dialect/sql"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentity"
...@@ -120,6 +124,113 @@ type sqlQueryExecutor interface { ...@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
} }
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
type scopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*scopedKeyLockEntry
}
type scopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
return &scopedKeyLockRegistry{
locks: make(map[string]*scopedKeyLockEntry),
}
}
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*scopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &scopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func advisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
release := repositoryScopedKeyLocks.lock(keys...)
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
if err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil { if dbent.TxFromContext(ctx) != nil {
return fn(ctx) return fn(ctx)
...@@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA ...@@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err return err
} }
} else { } else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
update := client.AuthIdentity.UpdateOneID(identity.ID) update := client.AuthIdentity.UpdateOneID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.Metadata != nil { if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata)) update = update.SetMetadata(copyMetadata(input.Metadata))
} }
...@@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA ...@@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err return err
} }
} else { } else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
update := client.AuthIdentityChannel.UpdateOneID(channel.ID). update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID) SetIdentityID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.ChannelMetadata != nil { if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
} }
...@@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string { ...@@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
return keys return keys
} }
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity { func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records { for _, record := range records {
if record.UserID == userID { if record.UserID != userID {
return record continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
} }
} }
return nil return selected
} }
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool { func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
...@@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) ...@@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64)
} }
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel { func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records { for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID == userID { if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
return record continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
} }
} }
return nil return selected
} }
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool { func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
...@@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, ...@@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
} }
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client) var result *dbent.IdentityAdoptionDecision
if input.IdentityID != nil && *input.IdentityID > 0 { err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
if _, err := client.IdentityAdoptionDecision.Update(). client := clientFromContext(txCtx, r.client)
Where( releaseLocks, err := lockRepositoryScopedKeys(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID), txCtx,
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) { client,
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID) txAwareSQLExecutor(txCtx, r.sql, r.client),
s.Where(entsql.Or( identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
entsql.IsNull(col), )
entsql.NEQ(col, input.PendingAuthSessionID), if err != nil {
)) return err
}), }
). defer releaseLocks()
ClearIdentityID().
Save(ctx); err != nil { if input.IdentityID != nil && *input.IdentityID > 0 {
return nil, err if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(txCtx); err != nil {
return err
}
} }
}
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create(). create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID). SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName). SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar). SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now) SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil { if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID) create = create.SetIdentityID(*input.IdentityID)
} }
return create.Save(ctx)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return err
}
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
return err
})
if err != nil {
return nil, err
} }
return result, nil
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
SetAdoptDisplayName(input.AdoptDisplayName). keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
SetAdoptAvatar(input.AdoptAvatar) if identityID != nil && *identityID > 0 {
if input.IdentityID != nil { keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
update = update.SetIdentityID(*input.IdentityID)
} }
return update.Save(ctx) return keys
} }
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
......
package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "wechat-legacy@example.com",
Username: "wechat-legacy",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
legacyChannel, err := client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("oa").
SetChannelAppID("wx-app-legacy").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
Channel: "oa",
ChannelAppID: "wx-app-legacy",
ChannelSubject: "openid-legacy-123",
},
Metadata: map[string]any{"source": "canonical-bind"},
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
})
require.NoError(t, err)
require.NotNil(t, bound)
require.NotNil(t, bound.Identity)
require.NotNil(t, bound.Channel)
require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
require.Equal(t, legacyChannel.ID, bound.Channel.ID)
require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("oa"),
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, channelCount)
}
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "repo-adoption@example.com",
Username: "repo-adoption",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-repo-adoption").
SetIntent("bind_current_user").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(ctx)
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
require.True(t, loaded.AdoptDisplayName)
require.True(t, loaded.AdoptAvatar)
}
...@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if userIn == nil { if userIn == nil {
return nil return nil
} }
if err := r.ensureNormalizedEmailAvailable(ctx, 0, userIn.Email); err != nil {
return err
}
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化, // 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。 // 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
...@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
} }
var txClient *dbent.Client var txClient *dbent.Client
txCtx := ctx
if err == nil { if err == nil {
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
txClient = tx.Client() txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else { } else {
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil { if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
...@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
} }
} }
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
return err
}
created, err := txClient.User.Create(). created, err := txClient.User.Create().
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
...@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt). SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx) Save(txCtx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
} }
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil { if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err return err
} }
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil { if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err return err
} }
...@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn == nil { if userIn == nil {
return nil return nil
} }
if err := r.ensureNormalizedEmailAvailable(ctx, userIn.ID, userIn.Email); err != nil {
return err
}
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。 // 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
tx, err := r.client.Tx(ctx) tx, err := r.client.Tx(ctx)
...@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
} }
var txClient *dbent.Client var txClient *dbent.Client
txCtx := ctx
if err == nil { if err == nil {
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
txClient = tx.Client() txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else { } else {
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。 // 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil { if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
...@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client txClient = r.client
} }
} }
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
return err
}
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil) return translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
...@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn.BalanceNotifyThreshold == nil { if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold() updateOp = updateOp.ClearBalanceNotifyThreshold()
} }
updated, err := updateOp.Save(ctx) updated, err := updateOp.Save(txCtx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
} }
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil { if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err return err
} }
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil { if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err return err
} }
...@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, ...@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
} }
func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error { func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, userID int64, email string) error {
matches, err := r.client.User.Query(). return ensureNormalizedEmailAvailableWithClient(ctx, clientFromContext(ctx, r.client), userID, email)
}
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
client = clientFromContext(ctx, client)
if client == nil {
return nil
}
matches, err := client.User.Query().
Where(userEmailLookupPredicate(email)). Where(userEmailLookupPredicate(email)).
All(ctx) All(ctx)
if err != nil { if err != nil {
...@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use ...@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
} }
func userEmailLookupPredicate(email string) predicate.User { func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email)) normalized := normalizeEmailLookupValue(email)
if normalized == "" { if normalized == "" {
return dbuser.EmailEQ(email) return dbuser.EmailEQ(email)
} }
...@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User { ...@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
}) })
} }
func normalizeEmailLookupValue(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
func normalizedEmailUniquenessLockKey(email string) string {
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return ""
}
return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
err := client.UserAllowedGroup.Create(). err := client.UserAllowedGroup.Create().
...@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { ...@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
} }
func userSignupSourceOrDefault(signupSource string) string { func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource) switch strings.TrimSpace(strings.ToLower(signupSource)) {
if signupSource == "" { case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email" return "email"
} }
return signupSource
} }
// marshalExtraEmails serializes notify email entries to JSON for storage. // marshalExtraEmails serializes notify email entries to JSON for storage.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment