Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
36aed359
Commit
36aed359
authored
Apr 22, 2026
by
IanShaw027
Browse files
fix(auth): harden oauth identity upgrade paths
parent
3d29f7c2
Changes
32
Hide whitespace changes
Inline
Side-by-side
backend/ent/schema/auth_identity_schema_test.go
View file @
36aed359
...
...
@@ -3,7 +3,9 @@ package schema
import
(
"testing"
"entgo.io/ent"
"entgo.io/ent/entc/load"
"entgo.io/ent/schema/field"
"github.com/stretchr/testify/require"
)
...
...
@@ -74,6 +76,17 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
userSchema
:=
requireSchema
(
t
,
schemas
,
"User"
)
requireSchemaFields
(
t
,
userSchema
,
"signup_source"
,
"last_login_at"
,
"last_active_at"
)
signupSource
:=
requireSchemaField
(
t
,
userSchema
,
"signup_source"
)
require
.
Equal
(
t
,
field
.
TypeString
,
signupSource
.
Info
.
Type
)
require
.
True
(
t
,
signupSource
.
Default
)
require
.
Equal
(
t
,
"email"
,
signupSource
.
DefaultValue
)
require
.
Equal
(
t
,
1
,
signupSource
.
Validators
)
validator
:=
requireStringFieldValidator
(
t
,
User
{}
.
Fields
(),
"signup_source"
)
for
_
,
value
:=
range
[]
string
{
"email"
,
"linuxdo"
,
"wechat"
,
"oidc"
}
{
require
.
NoError
(
t
,
validator
(
value
))
}
require
.
Error
(
t
,
validator
(
"github"
))
}
func
requireSchema
(
t
*
testing
.
T
,
schemas
map
[
string
]
*
load
.
Schema
,
name
string
)
*
load
.
Schema
{
...
...
@@ -98,6 +111,37 @@ func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) {
}
}
func
requireSchemaField
(
t
*
testing
.
T
,
schema
*
load
.
Schema
,
name
string
)
*
load
.
Field
{
t
.
Helper
()
for
_
,
schemaField
:=
range
schema
.
Fields
{
if
schemaField
.
Name
==
name
{
return
schemaField
}
}
require
.
Failf
(
t
,
"missing schema field"
,
"schema %s should include field %s"
,
schema
.
Name
,
name
)
return
nil
}
func
requireStringFieldValidator
(
t
*
testing
.
T
,
fields
[]
ent
.
Field
,
name
string
)
func
(
string
)
error
{
t
.
Helper
()
for
_
,
entField
:=
range
fields
{
descriptor
:=
entField
.
Descriptor
()
if
descriptor
.
Name
!=
name
{
continue
}
require
.
NotEmpty
(
t
,
descriptor
.
Validators
,
"field %s should include a validator"
,
name
)
validator
,
ok
:=
descriptor
.
Validators
[
0
]
.
(
func
(
string
)
error
)
require
.
True
(
t
,
ok
,
"field %s validator should be func(string) error"
,
name
)
return
validator
}
require
.
Failf
(
t
,
"missing field validator"
,
"schema should include field %s"
,
name
)
return
nil
}
func
requireHasUniqueIndex
(
t
*
testing
.
T
,
schema
*
load
.
Schema
,
fields
...
string
)
{
t
.
Helper
()
...
...
backend/ent/schema/user.go
View file @
36aed359
package
schema
import
(
"fmt"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/internal/domain"
...
...
@@ -73,7 +75,14 @@ func (User) Fields() []ent.Field {
Optional
()
.
Nillable
(),
field
.
String
(
"signup_source"
)
.
MaxLen
(
20
)
.
Validate
(
func
(
value
string
)
error
{
switch
value
{
case
"email"
,
"linuxdo"
,
"wechat"
,
"oidc"
:
return
nil
default
:
return
fmt
.
Errorf
(
"must be one of email, linuxdo, wechat, oidc"
)
}
})
.
Default
(
"email"
),
field
.
Time
(
"last_login_at"
)
.
Optional
()
.
...
...
backend/internal/config/config.go
View file @
36aed359
...
...
@@ -211,25 +211,27 @@ type WeChatConnectConfig struct {
}
type
OIDCConnectConfig
struct
{
Enabled
bool
`mapstructure:"enabled"`
ProviderName
string
`mapstructure:"provider_name"`
// 显示名: "Keycloak" 等
ClientID
string
`mapstructure:"client_id"`
ClientSecret
string
`mapstructure:"client_secret"`
IssuerURL
string
`mapstructure:"issuer_url"`
DiscoveryURL
string
`mapstructure:"discovery_url"`
AuthorizeURL
string
`mapstructure:"authorize_url"`
TokenURL
string
`mapstructure:"token_url"`
UserInfoURL
string
`mapstructure:"userinfo_url"`
JWKSURL
string
`mapstructure:"jwks_url"`
Scopes
string
`mapstructure:"scopes"`
// 默认 "openid email profile"
RedirectURL
string
`mapstructure:"redirect_url"`
// 后端回调地址(需在提供方后台登记)
FrontendRedirectURL
string
`mapstructure:"frontend_redirect_url"`
// 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod
string
`mapstructure:"token_auth_method"`
// client_secret_post / client_secret_basic / none
UsePKCE
bool
`mapstructure:"use_pkce"`
ValidateIDToken
bool
`mapstructure:"validate_id_token"`
AllowedSigningAlgs
string
`mapstructure:"allowed_signing_algs"`
// 默认 "RS256,ES256,PS256"
ClockSkewSeconds
int
`mapstructure:"clock_skew_seconds"`
// 默认 120
RequireEmailVerified
bool
`mapstructure:"require_email_verified"`
// 默认 false
Enabled
bool
`mapstructure:"enabled"`
ProviderName
string
`mapstructure:"provider_name"`
// 显示名: "Keycloak" 等
ClientID
string
`mapstructure:"client_id"`
ClientSecret
string
`mapstructure:"client_secret"`
IssuerURL
string
`mapstructure:"issuer_url"`
DiscoveryURL
string
`mapstructure:"discovery_url"`
AuthorizeURL
string
`mapstructure:"authorize_url"`
TokenURL
string
`mapstructure:"token_url"`
UserInfoURL
string
`mapstructure:"userinfo_url"`
JWKSURL
string
`mapstructure:"jwks_url"`
Scopes
string
`mapstructure:"scopes"`
// 默认 "openid email profile"
RedirectURL
string
`mapstructure:"redirect_url"`
// 后端回调地址(需在提供方后台登记)
FrontendRedirectURL
string
`mapstructure:"frontend_redirect_url"`
// 前端接收 token 的路由(默认:/auth/oidc/callback)
TokenAuthMethod
string
`mapstructure:"token_auth_method"`
// client_secret_post / client_secret_basic / none
UsePKCE
bool
`mapstructure:"use_pkce"`
ValidateIDToken
bool
`mapstructure:"validate_id_token"`
UsePKCEExplicit
bool
`mapstructure:"-" yaml:"-"`
ValidateIDTokenExplicit
bool
`mapstructure:"-" yaml:"-"`
AllowedSigningAlgs
string
`mapstructure:"allowed_signing_algs"`
// 默认 "RS256,ES256,PS256"
ClockSkewSeconds
int
`mapstructure:"clock_skew_seconds"`
// 默认 120
RequireEmailVerified
bool
`mapstructure:"require_email_verified"`
// 默认 false
// 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。
// 为空时,服务端会尝试一组常见字段名。
...
...
@@ -329,6 +331,14 @@ func shouldApplyLegacyWeChatEnv(configKey, envKey string) bool {
return
!
hasNewEnv
}
func
hasExplicitConfigOrEnv
(
configKey
,
envKey
string
)
bool
{
if
viper
.
InConfig
(
configKey
)
{
return
true
}
_
,
ok
:=
os
.
LookupEnv
(
envKey
)
return
ok
}
func
applyLegacyWeChatConnectEnvCompatibility
(
cfg
*
WeChatConnectConfig
)
{
if
cfg
==
nil
{
return
...
...
@@ -1262,6 +1272,8 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg
.
OIDC
.
UserInfoEmailPath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoEmailPath
)
cfg
.
OIDC
.
UserInfoIDPath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoIDPath
)
cfg
.
OIDC
.
UserInfoUsernamePath
=
strings
.
TrimSpace
(
cfg
.
OIDC
.
UserInfoUsernamePath
)
cfg
.
OIDC
.
UsePKCEExplicit
=
hasExplicitConfigOrEnv
(
"oidc_connect.use_pkce"
,
"OIDC_CONNECT_USE_PKCE"
)
cfg
.
OIDC
.
ValidateIDTokenExplicit
=
hasExplicitConfigOrEnv
(
"oidc_connect.validate_id_token"
,
"OIDC_CONNECT_VALIDATE_ID_TOKEN"
)
cfg
.
Dashboard
.
KeyPrefix
=
strings
.
TrimSpace
(
cfg
.
Dashboard
.
KeyPrefix
)
cfg
.
CORS
.
AllowedOrigins
=
normalizeStringSlice
(
cfg
.
CORS
.
AllowedOrigins
)
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
=
normalizeStringSlice
(
cfg
.
Security
.
ResponseHeaders
.
AdditionalAllowed
)
...
...
backend/internal/config/config_test.go
View file @
36aed359
...
...
@@ -254,6 +254,21 @@ func TestLoadDefaultOIDCSecurityDefaults(t *testing.T) {
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
cfg
.
OIDC
.
UsePKCE
)
require
.
True
(
t
,
cfg
.
OIDC
.
ValidateIDToken
)
require
.
False
(
t
,
cfg
.
OIDC
.
UsePKCEExplicit
)
require
.
False
(
t
,
cfg
.
OIDC
.
ValidateIDTokenExplicit
)
}
func
TestLoadExplicitOIDCSecurityDefaultsFromEnvMarksFlagsExplicit
(
t
*
testing
.
T
)
{
resetViperWithJWTSecret
(
t
)
t
.
Setenv
(
"OIDC_CONNECT_USE_PKCE"
,
"false"
)
t
.
Setenv
(
"OIDC_CONNECT_VALIDATE_ID_TOKEN"
,
"false"
)
cfg
,
err
:=
Load
()
require
.
NoError
(
t
,
err
)
require
.
False
(
t
,
cfg
.
OIDC
.
UsePKCE
)
require
.
False
(
t
,
cfg
.
OIDC
.
ValidateIDToken
)
require
.
True
(
t
,
cfg
.
OIDC
.
UsePKCEExplicit
)
require
.
True
(
t
,
cfg
.
OIDC
.
ValidateIDTokenExplicit
)
}
func
TestLoadForcedCodexInstructionsTemplate
(
t
*
testing
.
T
)
{
...
...
backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
View file @
36aed359
...
...
@@ -335,6 +335,75 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
require
.
Equal
(
t
,
false
,
data
[
"oidc_connect_validate_id_token"
])
}
func
TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
settingHandlerRepoStub
{
values
:
map
[
string
]
string
{
service
.
SettingKeyPromoCodeEnabled
:
"true"
,
service
.
SettingKeyOIDCConnectEnabled
:
"true"
,
service
.
SettingKeyOIDCConnectProviderName
:
"OIDC"
,
service
.
SettingKeyOIDCConnectClientID
:
"oidc-client"
,
service
.
SettingKeyOIDCConnectClientSecret
:
"oidc-secret"
,
service
.
SettingKeyOIDCConnectIssuerURL
:
"https://issuer.example.com"
,
service
.
SettingKeyOIDCConnectAuthorizeURL
:
"https://issuer.example.com/auth"
,
service
.
SettingKeyOIDCConnectTokenURL
:
"https://issuer.example.com/token"
,
service
.
SettingKeyOIDCConnectUserInfoURL
:
"https://issuer.example.com/userinfo"
,
service
.
SettingKeyOIDCConnectJWKSURL
:
"https://issuer.example.com/jwks"
,
service
.
SettingKeyOIDCConnectScopes
:
"openid email profile"
,
service
.
SettingKeyOIDCConnectRedirectURL
:
"https://example.com/api/v1/auth/oauth/oidc/callback"
,
service
.
SettingKeyOIDCConnectFrontendRedirectURL
:
"/auth/oidc/callback"
,
service
.
SettingKeyOIDCConnectTokenAuthMethod
:
"client_secret_post"
,
service
.
SettingKeyOIDCConnectAllowedSigningAlgs
:
"RS256"
,
service
.
SettingKeyOIDCConnectClockSkewSeconds
:
"120"
,
service
.
SettingKeyOIDCConnectRequireEmailVerified
:
"false"
,
service
.
SettingKeyOIDCConnectUserInfoEmailPath
:
""
,
service
.
SettingKeyOIDCConnectUserInfoIDPath
:
""
,
service
.
SettingKeyOIDCConnectUserInfoUsernamePath
:
""
,
},
}
svc
:=
service
.
NewSettingService
(
repo
,
&
config
.
Config
{
Default
:
config
.
DefaultConfig
{
UserConcurrency
:
5
},
OIDC
:
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ProviderName
:
"OIDC"
,
ClientID
:
"oidc-client"
,
ClientSecret
:
"oidc-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
AuthorizeURL
:
"https://issuer.example.com/auth"
,
TokenURL
:
"https://issuer.example.com/token"
,
UserInfoURL
:
"https://issuer.example.com/userinfo"
,
JWKSURL
:
"https://issuer.example.com/jwks"
,
Scopes
:
"openid email profile"
,
RedirectURL
:
"https://example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
true
,
ValidateIDToken
:
true
,
AllowedSigningAlgs
:
"RS256"
,
ClockSkewSeconds
:
120
,
},
})
handler
:=
NewSettingHandler
(
svc
,
nil
,
nil
,
nil
,
nil
,
nil
)
body
:=
map
[
string
]
any
{
"promo_code_enabled"
:
true
,
"oidc_connect_enabled"
:
true
,
}
rawBody
,
err
:=
json
.
Marshal
(
body
)
require
.
NoError
(
t
,
err
)
rec
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
rec
)
c
.
Request
=
httptest
.
NewRequest
(
http
.
MethodPut
,
"/api/v1/admin/settings"
,
bytes
.
NewReader
(
rawBody
))
c
.
Request
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
handler
.
UpdateSettings
(
c
)
require
.
Equal
(
t
,
http
.
StatusOK
,
rec
.
Code
)
require
.
Equal
(
t
,
"false"
,
repo
.
values
[
service
.
SettingKeyOIDCConnectUsePKCE
])
require
.
Equal
(
t
,
"false"
,
repo
.
values
[
service
.
SettingKeyOIDCConnectValidateIDToken
])
}
func
TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource
(
t
*
testing
.
T
)
{
gin
.
SetMode
(
gin
.
TestMode
)
repo
:=
&
settingHandlerRepoStub
{
...
...
backend/internal/handler/auth_linuxdo_oauth.go
View file @
36aed359
...
...
@@ -355,15 +355,20 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
}
userEntity
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEqualFold
(
email
))
.
Only
(
ctx
)
Where
(
userNormalizedEmailPredicate
(
email
))
.
Order
(
dbent
.
Asc
(
dbuser
.
FieldID
))
.
All
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
infraerrors
.
InternalServer
(
"COMPAT_EMAIL_LOOKUP_FAILED"
,
"failed to look up compat email user"
)
.
WithCause
(
err
)
}
return
userEntity
,
nil
switch
len
(
userEntity
)
{
case
0
:
return
nil
,
nil
case
1
:
return
userEntity
[
0
],
nil
default
:
return
nil
,
infraerrors
.
Conflict
(
"USER_EMAIL_CONFLICT"
,
"normalized email matched multiple users"
)
}
}
func
(
h
*
AuthHandler
)
createLinuxDoOAuthChoicePendingSession
(
...
...
@@ -411,9 +416,15 @@ func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
completionResponse
[
"choice_reason"
]
=
"force_email_on_signup"
}
var
targetUserID
*
int64
if
compatEmailUser
!=
nil
&&
compatEmailUser
.
ID
>
0
{
targetUserID
=
&
compatEmailUser
.
ID
}
return
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Identity
:
identity
,
TargetUserID
:
targetUserID
,
ResolvedEmail
:
resolvedChoiceEmail
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
...
...
@@ -490,9 +501,13 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
client
:=
h
.
entClient
()
if
client
==
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
))
return
}
if
err
:=
ensurePendingOAuthRegistrationIdentityAvailable
(
c
.
Request
.
Context
(),
client
,
session
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
decision
,
err
:=
h
.
ensurePendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
...
...
@@ -503,17 +518,16 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
err
:=
applyPendingOAuthAdoption
(
c
.
Request
.
Context
(),
h
.
entClient
(),
h
.
authService
,
h
.
userService
,
session
,
decision
,
&
user
.
ID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_ADOPTION_APPLY_FAILED"
,
"failed to apply oauth profile adoption"
)
.
WithCause
(
err
))
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
if
_
,
err
:=
pendingSvc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
);
err
!=
nil
{
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
response
.
ErrorFrom
(
c
,
err
)
if
err
:=
applyPendingOAuthAdoptionAndConsumeSession
(
c
.
Request
.
Context
(),
client
,
h
.
authService
,
h
.
userService
,
session
,
decision
,
user
.
ID
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
...
...
backend/internal/handler/auth_linuxdo_oauth_test.go
View file @
36aed359
...
...
@@ -508,7 +508,7 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
ctx
:=
context
.
Background
()
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"
l
egacy@
e
xample.com"
)
.
SetEmail
(
"
L
egacy@
E
xample.com
"
)
.
SetUsername
(
"legacy-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
...
...
@@ -539,16 +539,17 @@ func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *test
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
Nil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
strings
.
TrimSpace
(
existingUser
.
Email
),
session
.
ResolvedEmail
)
require
.
Equal
(
t
,
"legacy@example.com"
,
session
.
UpstreamIdentityClaims
[
"compat_email"
])
completion
,
ok
:=
session
.
LocalFlowState
[
oauthCompletionResponseKey
]
.
(
map
[
string
]
any
)
require
.
True
(
t
,
ok
)
require
.
Equal
(
t
,
"/dashboard"
,
completion
[
"redirect"
])
require
.
Equal
(
t
,
oauthPendingChoiceStep
,
completion
[
"step"
])
require
.
Equal
(
t
,
existingUser
.
Email
,
completion
[
"email"
])
require
.
Equal
(
t
,
existingUser
.
Email
,
completion
[
"existing_account_email"
])
require
.
Equal
(
t
,
strings
.
TrimSpace
(
existingUser
.
Email
)
,
completion
[
"email"
])
require
.
Equal
(
t
,
strings
.
TrimSpace
(
existingUser
.
Email
)
,
completion
[
"existing_account_email"
])
require
.
Equal
(
t
,
true
,
completion
[
"existing_account_bindable"
])
require
.
Equal
(
t
,
"compat_email_match"
,
completion
[
"choice_reason"
])
_
,
hasAccessToken
:=
completion
[
"access_token"
]
...
...
@@ -943,6 +944,68 @@ func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *te
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
func
TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
existingOwner
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingOwner
.
ID
)
.
SetProviderType
(
"linuxdo"
)
.
SetProviderKey
(
"linuxdo"
)
.
SetProviderSubject
(
"linuxdo-conflict-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"linuxdo-complete-conflict-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"linuxdo"
)
.
SetProviderKey
(
"linuxdo"
)
.
SetProviderSubject
(
"linuxdo-conflict-subject"
)
.
SetResolvedEmail
(
"linuxdo-conflict-subject@linuxdo-connect.invalid"
)
.
SetBrowserSessionKey
(
"linuxdo-conflict-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"linuxdo_user"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/linuxdo/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"linuxdo-conflict-browser"
)})
c
.
Request
=
req
handler
.
CompleteLinuxDoOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
recorder
.
Code
)
payload
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
payload
[
"reason"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"linuxdo-conflict-subject@linuxdo-connect.invalid"
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
func
newLinuxDoOAuthTestHandler
(
t
*
testing
.
T
,
invitationEnabled
bool
,
oauthCfg
config
.
LinuxDoConnectConfig
)
*
AuthHandler
{
t
.
Helper
()
handler
,
_
:=
newLinuxDoOAuthHandlerAndClient
(
t
,
invitationEnabled
,
oauthCfg
)
...
...
backend/internal/handler/auth_oauth_pending_flow.go
View file @
36aed359
...
...
@@ -519,7 +519,7 @@ func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
email
:=
strings
.
TrimSpace
(
strings
.
ToLower
(
req
.
Email
))
if
existingUser
,
err
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
);
err
==
nil
&&
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -704,6 +704,38 @@ func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email
return
matches
[
0
],
nil
}
func
ensurePendingOAuthRegistrationIdentityAvailable
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
)
error
{
if
client
==
nil
||
session
==
nil
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
session
.
ProviderType
)),
authidentity
.
ProviderKeyEQ
(
strings
.
TrimSpace
(
session
.
ProviderKey
)),
authidentity
.
ProviderSubjectEQ
(
strings
.
TrimSpace
(
session
.
ProviderSubject
)),
)
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
}
return
err
}
if
identity
==
nil
||
identity
.
UserID
<=
0
{
return
nil
}
activeOwner
,
err
:=
findActiveUserByID
(
ctx
,
client
,
identity
.
UserID
)
if
err
!=
nil
{
return
err
}
if
activeOwner
!=
nil
{
return
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
return
nil
}
func
oauthIdentityIssuer
(
session
*
dbent
.
PendingAuthSession
)
*
string
{
if
session
==
nil
{
return
nil
...
...
@@ -1206,6 +1238,38 @@ func consumePendingOAuthBrowserSessionTx(
return
nil
}
func
applyPendingOAuthAdoptionAndConsumeSession
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
authService
*
service
.
AuthService
,
userService
*
service
.
UserService
,
session
*
dbent
.
PendingAuthSession
,
decision
*
dbent
.
IdentityAdoptionDecision
,
userID
int64
,
)
error
{
if
client
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
if
session
==
nil
||
userID
<=
0
{
return
infraerrors
.
BadRequest
(
"PENDING_AUTH_SESSION_INVALID"
,
"pending auth registration context is invalid"
)
}
tx
,
err
:=
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
err
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txCtx
:=
dbent
.
NewTxContext
(
ctx
,
tx
)
if
err
:=
applyPendingOAuthAdoption
(
txCtx
,
client
,
authService
,
userService
,
session
,
decision
,
&
userID
);
err
!=
nil
{
return
err
}
if
err
:=
consumePendingOAuthBrowserSessionTx
(
txCtx
,
tx
,
session
);
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
applyPendingOAuthAdoption
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
...
...
@@ -1448,16 +1512,21 @@ func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c
*
gin
.
Context
,
client
*
dbent
.
Client
,
session
*
dbent
.
PendingAuthSession
,
targetUser
*
dbent
.
User
,
email
string
,
)
(
*
dbent
.
PendingAuthSession
,
error
)
{
completionResponse
:=
pendingOAuthChoiceCompletionResponse
(
session
,
email
)
var
targetUserID
*
int64
if
targetUser
!=
nil
&&
targetUser
.
ID
>
0
{
targetUserID
=
&
targetUser
.
ID
}
session
,
err
:=
updatePendingOAuthSessionProgress
(
c
.
Request
.
Context
(),
client
,
session
,
strings
.
TrimSpace
(
session
.
Intent
),
email
,
nil
,
targetUserID
,
completionResponse
,
)
if
err
!=
nil
{
...
...
@@ -1601,7 +1670,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
}
if
existingUser
!=
nil
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
@@ -1624,7 +1693,12 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
)
if
err
!=
nil
{
if
errors
.
Is
(
err
,
service
.
ErrEmailExists
)
{
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
email
)
existingUser
,
lookupErr
:=
findUserByNormalizedEmail
(
c
.
Request
.
Context
(),
client
,
email
)
if
lookupErr
!=
nil
{
response
.
ErrorFrom
(
c
,
lookupErr
)
return
}
session
,
err
=
h
.
transitionPendingOAuthAccountToChoiceState
(
c
,
client
,
session
,
existingUser
,
email
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
...
...
backend/internal/handler/auth_oauth_pending_flow_test.go
View file @
36aed359
...
...
@@ -1045,7 +1045,7 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1099,7 +1099,8 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsChoicePendingSessionState(t *
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
...
...
@@ -1118,7 +1119,7 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
" Owner@Example.com "
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1164,7 +1165,8 @@ func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *te
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
...
...
@@ -1172,7 +1174,7 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
handler
,
client
:=
newOAuthPendingFlowTestHandlerWithEmailVerification
(
t
,
false
,
"owner@example.com"
,
"135790"
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
existingUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
...
...
@@ -1220,7 +1222,8 @@ func TestSendPendingOAuthVerifyCodeExistingEmailReturnsBindLoginState(t *testing
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
storedSession
.
Intent
)
require
.
Nil
(
t
,
storedSession
.
TargetUserID
)
require
.
NotNil
(
t
,
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
storedSession
.
TargetUserID
)
require
.
Equal
(
t
,
"owner@example.com"
,
storedSession
.
ResolvedEmail
)
}
...
...
backend/internal/handler/auth_oidc_oauth.go
View file @
36aed359
...
...
@@ -563,10 +563,15 @@ func (h *AuthHandler) createOIDCOAuthChoicePendingSession(
if
compatEmailUser
!=
nil
{
resolvedChoiceEmail
=
strings
.
TrimSpace
(
compatEmailUser
.
Email
)
}
var
targetUserID
*
int64
if
compatEmailUser
!=
nil
&&
compatEmailUser
.
ID
>
0
{
targetUserID
=
&
compatEmailUser
.
ID
}
return
h
.
createOAuthPendingSession
(
c
,
oauthPendingSessionPayload
{
Intent
:
oauthIntentLogin
,
Identity
:
identity
,
TargetUserID
:
targetUserID
,
ResolvedEmail
:
resolvedChoiceEmail
,
RedirectTo
:
redirectTo
,
BrowserSessionKey
:
browserSessionKey
,
...
...
@@ -643,9 +648,13 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
return
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
client
:=
h
.
entClient
()
if
client
==
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
))
return
}
if
err
:=
ensurePendingOAuthRegistrationIdentityAvailable
(
c
.
Request
.
Context
(),
client
,
session
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
decision
,
err
:=
h
.
ensurePendingOAuthAdoptionDecision
(
c
,
session
.
ID
,
oauthAdoptionDecisionRequest
{
...
...
@@ -656,17 +665,16 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response
.
ErrorFrom
(
c
,
err
)
return
}
if
err
:=
applyPendingOAuthAdoption
(
c
.
Request
.
Context
(),
h
.
entClient
(),
h
.
authService
,
h
.
userService
,
session
,
decision
,
&
user
.
ID
);
err
!=
nil
{
response
.
ErrorFrom
(
c
,
infraerrors
.
InternalServer
(
"PENDING_AUTH_ADOPTION_APPLY_FAILED"
,
"failed to apply oauth profile adoption"
)
.
WithCause
(
err
))
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
email
,
username
,
req
.
InvitationCode
)
if
err
!=
nil
{
response
.
ErrorFrom
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
if
_
,
err
:=
pendingSvc
.
ConsumeBrowserSession
(
c
.
Request
.
Context
(),
sessionToken
,
browserSessionKey
);
err
!=
nil
{
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
response
.
ErrorFrom
(
c
,
err
)
if
err
:=
applyPendingOAuthAdoptionAndConsumeSession
(
c
.
Request
.
Context
(),
client
,
h
.
authService
,
h
.
userService
,
session
,
decision
,
user
.
ID
);
err
!=
nil
{
respondPendingOAuthBindingApplyError
(
c
,
err
)
return
}
h
.
authService
.
RecordSuccessfulLogin
(
c
.
Request
.
Context
(),
user
.
ID
)
clearOAuthPendingSessionCookie
(
c
,
secureCookie
)
clearOAuthPendingBrowserCookie
(
c
,
secureCookie
)
...
...
backend/internal/handler/auth_oidc_oauth_test.go
View file @
36aed359
...
...
@@ -438,7 +438,8 @@ func TestOIDCOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
oauthIntentLogin
,
session
.
Intent
)
require
.
Nil
(
t
,
session
.
TargetUserID
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
existingUser
.
Email
,
session
.
ResolvedEmail
)
require
.
Equal
(
t
,
"legacy@example.com"
,
session
.
UpstreamIdentityClaims
[
"compat_email"
])
...
...
@@ -862,6 +863,69 @@ func TestCompleteOIDCOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testi
require
.
False
(
t
,
decision
.
AdoptAvatar
)
}
func
TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
existingOwner
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"owner@example.com"
)
.
SetUsername
(
"owner-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
existingOwner
.
ID
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"oidc-complete-conflict-session"
)
.
SetIntent
(
"login"
)
.
SetProviderType
(
"oidc"
)
.
SetProviderKey
(
"https://issuer.example.com"
)
.
SetProviderSubject
(
"oidc-conflict-subject"
)
.
SetResolvedEmail
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
)
.
SetBrowserSessionKey
(
"oidc-conflict-browser"
)
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"username"
:
"oidc_user"
,
"issuer"
:
"https://issuer.example.com"
,
})
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
10
*
time
.
Minute
))
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
body
:=
bytes
.
NewBufferString
(
`{"invitation_code":"invite-1"}`
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodPost
,
"/api/v1/auth/oauth/oidc/complete-registration"
,
body
)
req
.
Header
.
Set
(
"Content-Type"
,
"application/json"
)
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingSessionCookieName
,
Value
:
encodeCookieValue
(
session
.
SessionToken
)})
req
.
AddCookie
(
&
http
.
Cookie
{
Name
:
oauthPendingBrowserCookieName
,
Value
:
encodeCookieValue
(
"oidc-conflict-browser"
)})
c
.
Request
=
req
handler
.
CompleteOIDCOAuthRegistration
(
c
)
require
.
Equal
(
t
,
http
.
StatusConflict
,
recorder
.
Code
)
payload
:=
decodeJSONBody
(
t
,
recorder
)
require
.
Equal
(
t
,
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
payload
[
"reason"
])
userCount
,
err
:=
client
.
User
.
Query
()
.
Where
(
dbuser
.
EmailEQ
(
"f6f5f1f16f9248ccb11e0d633963b290@oidc-connect.invalid"
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
userCount
)
storedSession
,
err
:=
client
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
type
oidcProviderFixture
struct
{
Subject
string
PreferredUsername
string
...
...
backend/internal/repository/auth_identity_legacy_migration_integration_test.go
View file @
36aed359
...
...
@@ -576,6 +576,258 @@ FROM auth_identity_migration_reports
require
.
Equal
(
t
,
beforeCount
,
afterCount
)
}
func
TestAuthIdentityLegacyExternalBackfillMigration_SkipsAmbiguousCanonicalSubjects
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migrationPath
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migrationSQL
,
err
:=
os
.
ReadFile
(
migrationPath
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-ambiguous-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-a', 'Legacy LinuxDo Ambiguous A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-ambiguous-subject', NULL, 'legacy-linuxdo-ambiguous-b', 'Legacy LinuxDo Ambiguous B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-a', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-a', 'Legacy WeChat Ambiguous A', '{"channel":"oa","appid":"wx-ambiguous-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
new
(
int64
)))
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-ambiguous-b', 'union-ambiguous-subject', 'legacy-wechat-ambiguous-b', 'Legacy WeChat Ambiguous B', '{"channel":"oa","appid":"wx-ambiguous-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
new
(
int64
)))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migrationSQL
))
require
.
NoError
(
t
,
err
)
var
linuxDoIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-ambiguous-subject'
`
)
.
Scan
(
&
linuxDoIdentityCount
))
require
.
Zero
(
t
,
linuxDoIdentityCount
)
var
wechatIdentityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-ambiguous-subject'
`
)
.
Scan
(
&
wechatIdentityCount
))
require
.
Zero
(
t
,
wechatIdentityCount
)
var
wechatChannelCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_channels
WHERE provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND channel = 'oa'
AND channel_app_id IN ('wx-ambiguous-a', 'wx-ambiguous-b')
`
)
.
Scan
(
&
wechatChannelCount
))
require
.
Zero
(
t
,
wechatChannelCount
)
}
func
TestAuthIdentityLegacyExternalMigrations_ReportAmbiguousCanonicalSubjectsWithoutWinnerAttribution
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
migration115Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"115_auth_identity_legacy_external_backfill.sql"
)
migration115SQL
,
err
:=
os
.
ReadFile
(
migration115Path
)
require
.
NoError
(
t
,
err
)
migration116Path
:=
filepath
.
Join
(
".."
,
".."
,
"migrations"
,
"116_auth_identity_legacy_external_safety_reports.sql"
)
migration116SQL
,
err
:=
os
.
ReadFile
(
migration116Path
)
require
.
NoError
(
t
,
err
)
prepareLegacyExternalIdentitiesTable
(
t
,
tx
,
ctx
)
truncateAuthIdentityLegacyFixtureTables
(
t
,
tx
,
ctx
)
var
linuxDoFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoFirstUserID
))
var
linuxDoSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
linuxDoSecondUserID
))
var
wechatFirstUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-a@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatFirstUserID
))
var
wechatSecondUserID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-conflict-b@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`
)
.
Scan
(
&
wechatSecondUserID
))
var
linuxDoFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-a', 'Legacy LinuxDo Conflict A', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoFirstUserID
)
.
Scan
(
&
linuxDoFirstLegacyID
))
var
linuxDoSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict-subject', NULL, 'legacy-linuxdo-conflict-b', 'Legacy LinuxDo Conflict B', '{"source":"legacy"}')
RETURNING id
`
,
linuxDoSecondUserID
)
.
Scan
(
&
linuxDoSecondLegacyID
))
var
wechatFirstLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-a', 'union-conflict-subject', 'legacy-wechat-conflict-a', 'Legacy WeChat Conflict A', '{"channel":"oa","appid":"wx-conflict-a"}')
RETURNING id
`
,
wechatFirstUserID
)
.
Scan
(
&
wechatFirstLegacyID
))
var
wechatSecondLegacyID
int64
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-conflict-b', 'union-conflict-subject', 'legacy-wechat-conflict-b', 'Legacy WeChat Conflict B', '{"channel":"oa","appid":"wx-conflict-b"}')
RETURNING id
`
,
wechatSecondUserID
)
.
Scan
(
&
wechatSecondLegacyID
))
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration115SQL
))
require
.
NoError
(
t
,
err
)
_
,
err
=
tx
.
ExecContext
(
ctx
,
string
(
migration116SQL
))
require
.
NoError
(
t
,
err
)
var
identityCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identities
WHERE (provider_type = 'linuxdo' AND provider_key = 'linuxdo' AND provider_subject = 'linuxdo-conflict-subject')
OR (provider_type = 'wechat' AND provider_key = 'wechat-main' AND provider_subject = 'union-conflict-subject')
`
)
.
Scan
(
&
identityCount
))
require
.
Zero
(
t
,
identityCount
)
var
conflictReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
conflictReportCount
))
require
.
Equal
(
t
,
4
,
conflictReportCount
)
var
winnerAttributedReportCount
int
require
.
NoError
(
t
,
tx
.
QueryRowContext
(
ctx
,
`
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key IN ($1, $2, $3, $4)
AND details ->> 'existing_identity_id' IS NOT NULL
`
,
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
linuxDoSecondLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatFirstLegacyID
,
10
),
"legacy_external_identity:"
+
strconv
.
FormatInt
(
wechatSecondLegacyID
,
10
))
.
Scan
(
&
winnerAttributedReportCount
))
require
.
Zero
(
t
,
winnerAttributedReportCount
)
}
func
TestAuthIdentityMigrationReportTypeWideningPreflightKeeps109And116SafeBefore121
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
ctx
:=
context
.
Background
()
...
...
backend/internal/repository/migrations_runner.go
View file @
36aed359
...
...
@@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
const
migrationsAdvisoryLockID
int64
=
694208311321144027
const
migrationsLockRetryInterval
=
500
*
time
.
Millisecond
const
nonTransactionalMigrationSuffix
=
"_notx.sql"
const
paymentOrdersOutTradeNoUniqueMigration
=
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
const
paymentOrdersOutTradeNoUniqueIndex
=
"paymentorder_out_trade_no_unique"
type
migrationChecksumCompatibilityRule
struct
{
fileChecksum
string
...
...
@@ -65,9 +67,11 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"054_drop_legacy_cache_columns.sql"
:
newMigrationChecksumCompatibilityRule
(
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d"
,
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4"
),
"061_add_usage_log_request_type.sql"
:
newMigrationChecksumCompatibilityRule
(
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c"
,
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0"
,
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3"
),
"109_auth_identity_compat_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3"
,
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee"
),
"115_auth_identity_legacy_external_backfill.sql"
:
newMigrationChecksumCompatibilityRule
(
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
),
"116_auth_identity_legacy_external_safety_reports.sql"
:
newMigrationChecksumCompatibilityRule
(
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
),
"118_wechat_dual_mode_and_auth_source_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"b54194d7a3e4fbf710e0a3590d22a2fe7966804c487052a356e0b55f53ef96b0"
,
"e0cdf835d6c688d64100f483d31bc02ac9ebad414bf1837af239a84bf75b8227"
),
"119_enforce_payment_orders_out_trade_no_unique.sql"
:
newMigrationChecksumCompatibilityRule
(
"0bbe809ae48a9d811dabda1ba1c74955bd71c4a9cc610f9128816818dfa6c11e"
,
"ebd2c67cce0116393fb4f1b5d5116a67c6aceb73820dfb5133d1ff6f36d72d34"
),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
newMigrationChecksumCompatibilityRule
(
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
),
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
newMigrationChecksumCompatibilityRule
(
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
),
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql"
:
newMigrationChecksumCompatibilityRule
(
"2ce43c2cd89e9f9e1febd34a407ed9e84d177386c5544b6f02c1f58a21129f57"
,
"6cd33422f215dcd1f486ab6f35c0ea5805d9ca69bb25906d94bc649156657145"
),
}
...
...
@@ -195,6 +199,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
}
if
nonTx
{
if
err
:=
prepareNonTransactionalMigration
(
ctx
,
db
,
name
);
err
!=
nil
{
return
fmt
.
Errorf
(
"prepare migration %s: %w"
,
name
,
err
)
}
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements
:=
splitSQLStatements
(
content
)
...
...
@@ -244,6 +252,88 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return
nil
}
func
prepareNonTransactionalMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
name
string
)
error
{
switch
name
{
case
paymentOrdersOutTradeNoUniqueMigration
:
return
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
,
db
)
default
:
return
nil
}
}
func
preparePaymentOrdersOutTradeNoUniqueMigration
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
error
{
duplicates
,
err
:=
findDuplicatePaymentOrderOutTradeNos
(
ctx
,
db
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"precheck duplicate out_trade_no: %w"
,
err
)
}
if
len
(
duplicates
)
>
0
{
return
fmt
.
Errorf
(
"duplicate out_trade_no values block %s; remediate duplicates before retrying: %s"
,
paymentOrdersOutTradeNoUniqueMigration
,
strings
.
Join
(
duplicates
,
", "
),
)
}
invalid
,
err
:=
indexIsInvalid
(
ctx
,
db
,
paymentOrdersOutTradeNoUniqueIndex
)
if
err
!=
nil
{
return
fmt
.
Errorf
(
"check invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
if
!
invalid
{
return
nil
}
if
_
,
err
:=
db
.
ExecContext
(
ctx
,
fmt
.
Sprintf
(
"DROP INDEX CONCURRENTLY IF EXISTS %s"
,
paymentOrdersOutTradeNoUniqueIndex
));
err
!=
nil
{
return
fmt
.
Errorf
(
"drop invalid index %s: %w"
,
paymentOrdersOutTradeNoUniqueIndex
,
err
)
}
return
nil
}
func
findDuplicatePaymentOrderOutTradeNos
(
ctx
context
.
Context
,
db
*
sql
.
DB
)
([]
string
,
error
)
{
rows
,
err
:=
db
.
QueryContext
(
ctx
,
`
SELECT out_trade_no, COUNT(*) AS duplicate_count
FROM payment_orders
WHERE out_trade_no <> ''
GROUP BY out_trade_no
HAVING COUNT(*) > 1
ORDER BY duplicate_count DESC, out_trade_no
LIMIT 5
`
)
if
err
!=
nil
{
return
nil
,
err
}
defer
rows
.
Close
()
duplicates
:=
make
([]
string
,
0
,
5
)
for
rows
.
Next
()
{
var
outTradeNo
string
var
duplicateCount
int
if
err
:=
rows
.
Scan
(
&
outTradeNo
,
&
duplicateCount
);
err
!=
nil
{
return
nil
,
err
}
duplicates
=
append
(
duplicates
,
fmt
.
Sprintf
(
"%s (count=%d)"
,
outTradeNo
,
duplicateCount
))
}
if
err
:=
rows
.
Err
();
err
!=
nil
{
return
nil
,
err
}
return
duplicates
,
nil
}
func
indexIsInvalid
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
indexName
string
)
(
bool
,
error
)
{
var
invalid
bool
err
:=
db
.
QueryRowContext
(
ctx
,
`
SELECT EXISTS (
SELECT 1
FROM pg_class idx
JOIN pg_namespace ns ON ns.oid = idx.relnamespace
JOIN pg_index i ON i.indexrelid = idx.oid
WHERE ns.nspname = 'public'
AND idx.relname = $1
AND NOT i.indisvalid
)
`
,
indexName
)
.
Scan
(
&
invalid
)
return
invalid
,
err
}
func
ensureAtlasBaselineAligned
(
ctx
context
.
Context
,
db
*
sql
.
DB
,
fsys
fs
.
FS
)
error
{
hasLegacy
,
err
:=
tableExists
(
ctx
,
db
,
"schema_migrations"
)
if
err
!=
nil
{
...
...
backend/internal/repository/migrations_runner_checksum_test.go
View file @
36aed359
...
...
@@ -70,6 +70,24 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"115历史checksum可兼容修复后的legacy external backfill"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"115_auth_identity_legacy_external_backfill.sql"
,
"4cf39e508be9fd1a5aa41610cbbebeb80385c9adda45bf78a706de9db4f1385f"
,
"022aadd97bb53e755f0cf7a3a957e0cb1a1353b0c39ec4de3234acd2871fd04f"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"116历史checksum可兼容修复后的legacy external safety reports"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"116_auth_identity_legacy_external_safety_reports.sql"
,
"f7757bd929ac67ffb08ce69fa4cf20fad39dbff9d5a5085fb2adabb7607e5877"
,
"07edb09fa8d04ffb172b0621e3c22f4d1757d20a24ae267b3b36b087ab72d488"
,
)
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"119历史checksum可兼容占位文件"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
...
...
@@ -79,6 +97,21 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
require
.
True
(
t
,
ok
)
})
t
.
Run
(
"120多个历史checksum都可兼容新的notx修复版本"
,
func
(
t
*
testing
.
T
)
{
for
_
,
dbChecksum
:=
range
[]
string
{
"e77921f79d539bc24575cb9c16cbe566d2b23ce816190343d0a7568f6a3fcf61"
,
"707431450603e70a43ce9fbd61e0c12fa67da4875158ccefabacea069587ab22"
,
"04b082b5a239c525154fe9185d324ee2b05ff90da9297e10dba19f9be79aa59a"
,
}
{
ok
:=
isMigrationChecksumCompatible
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
dbChecksum
,
"34aadc0db59a4e390f92a12b73bd74642d9724f33124f73638ae00089ea5e074"
,
)
require
.
True
(
t
,
ok
)
}
})
t
.
Run
(
"119未知checksum不兼容"
,
func
(
t
*
testing
.
T
)
{
ok
:=
isMigrationChecksumCompatible
(
"119_enforce_payment_orders_out_trade_no_unique.sql"
,
...
...
backend/internal/repository/migrations_runner_extra_test.go
View file @
36aed359
...
...
@@ -96,6 +96,8 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
func
TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations
(
t
*
testing
.
T
)
{
for
_
,
name
:=
range
[]
string
{
"115_auth_identity_legacy_external_backfill.sql"
,
"116_auth_identity_legacy_external_safety_reports.sql"
,
"118_wechat_dual_mode_and_auth_source_defaults.sql"
,
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql"
,
...
...
backend/internal/repository/migrations_runner_notx_test.go
View file @
36aed359
...
...
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
db
.
Close
()
}()
prepareMigrationsBootstrapExpectations
(
mock
)
mock
.
ExpectQuery
(
"SELECT checksum FROM schema_migrations WHERE filename =
\\
$1"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
)
.
WillReturnError
(
sql
.
ErrNoRows
)
mock
.
ExpectQuery
(
"SELECT out_trade_no, COUNT
\\
(
\\
*
\\
) AS duplicate_count FROM payment_orders"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"out_trade_no"
,
"duplicate_count"
})
.
AddRow
(
"dup-out-trade-no"
,
2
))
mock
.
ExpectExec
(
"SELECT pg_advisory_unlock
\\
(
\\
$1
\\
)"
)
.
WithArgs
(
migrationsAdvisoryLockID
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
fsys
:=
fstest
.
MapFS
{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
&
fstest
.
MapFile
{
Data
:
[]
byte
(
`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`
),
},
}
err
=
applyMigrationsFS
(
context
.
Background
(),
db
,
fsys
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"duplicate out_trade_no"
)
require
.
Contains
(
t
,
err
.
Error
(),
"dup-out-trade-no"
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
db
.
Close
()
}()
prepareMigrationsBootstrapExpectations
(
mock
)
mock
.
ExpectQuery
(
"SELECT checksum FROM schema_migrations WHERE filename =
\\
$1"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
)
.
WillReturnError
(
sql
.
ErrNoRows
)
mock
.
ExpectQuery
(
"SELECT out_trade_no, COUNT
\\
(
\\
*
\\
) AS duplicate_count FROM payment_orders"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"out_trade_no"
,
"duplicate_count"
}))
mock
.
ExpectQuery
(
"SELECT EXISTS
\\
("
)
.
WithArgs
(
"paymentorder_out_trade_no_unique"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"exists"
})
.
AddRow
(
true
))
mock
.
ExpectExec
(
"DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"INSERT INTO schema_migrations
\\
(filename, checksum
\\
) VALUES
\\
(
\\
$1,
\\
$2
\\
)"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
sqlmock
.
AnyArg
())
.
WillReturnResult
(
sqlmock
.
NewResult
(
1
,
1
))
mock
.
ExpectExec
(
"SELECT pg_advisory_unlock
\\
(
\\
$1
\\
)"
)
.
WithArgs
(
migrationsAdvisoryLockID
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
fsys
:=
fstest
.
MapFS
{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
&
fstest
.
MapFile
{
Data
:
[]
byte
(
`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`
),
},
}
err
=
applyMigrationsFS
(
context
.
Background
(),
db
,
fsys
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_TransactionalMigration
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
...
...
backend/internal/repository/migrations_schema_integration_test.go
View file @
36aed359
...
...
@@ -93,6 +93,19 @@ func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T)
tx
:=
testTx
(
t
)
requireColumn
(
t
,
tx
,
"auth_identity_migration_reports"
,
"report_type"
,
"character varying"
,
80
,
false
)
requireColumn
(
t
,
tx
,
"users"
,
"signup_source"
,
"character varying"
,
20
,
false
)
requireColumnDefaultContains
(
t
,
tx
,
"users"
,
"signup_source"
,
"email"
)
requireConstraintDefinitionContains
(
t
,
tx
,
"users"
,
"users_signup_source_check"
,
"signup_source"
,
"'email'"
,
"'linuxdo'"
,
"'wechat'"
,
"'oidc'"
,
)
requireForeignKeyOnDelete
(
t
,
tx
,
"auth_identities"
,
"user_id"
,
"users"
,
"CASCADE"
)
requireForeignKeyOnDelete
(
t
,
tx
,
"auth_identity_channels"
,
"identity_id"
,
"auth_identities"
,
"CASCADE"
)
...
...
@@ -195,6 +208,45 @@ LIMIT 1
require
.
Equal
(
t
,
expected
,
actual
,
"unexpected ON DELETE action for %s.%s -> %s"
,
table
,
column
,
refTable
)
}
func
requireConstraintDefinitionContains
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
constraint
string
,
fragments
...
string
)
{
t
.
Helper
()
var
def
string
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`
,
table
,
constraint
)
.
Scan
(
&
def
)
require
.
NoError
(
t
,
err
,
"query constraint definition for %s.%s"
,
table
,
constraint
)
for
_
,
fragment
:=
range
fragments
{
require
.
Contains
(
t
,
def
,
fragment
,
"expected constraint definition for %s.%s to contain %q"
,
table
,
constraint
,
fragment
)
}
}
func
requireColumnDefaultContains
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
string
,
fragments
...
string
)
{
t
.
Helper
()
var
columnDefault
sql
.
NullString
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`
,
table
,
column
)
.
Scan
(
&
columnDefault
)
require
.
NoError
(
t
,
err
,
"query column_default for %s.%s"
,
table
,
column
)
require
.
True
(
t
,
columnDefault
.
Valid
,
"expected column_default for %s.%s"
,
table
,
column
)
for
_
,
fragment
:=
range
fragments
{
require
.
Contains
(
t
,
columnDefault
.
String
,
fragment
,
"expected default for %s.%s to contain %q"
,
table
,
column
,
fragment
)
}
}
func
requireColumn
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
,
dataType
string
,
maxLen
int
,
nullable
bool
)
{
t
.
Helper
()
...
...
backend/internal/repository/user_profile_identity_repo.go
View file @
36aed359
...
...
@@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"fmt"
"hash/fnv"
"reflect"
"sort"
"strings"
"sync"
"time"
"unsafe"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
...
...
@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
}
var
repositoryScopedKeyLocks
=
newScopedKeyLockRegistry
()
type
scopedKeyLockRegistry
struct
{
mu
sync
.
Mutex
locks
map
[
string
]
*
scopedKeyLockEntry
}
type
scopedKeyLockEntry
struct
{
mu
sync
.
Mutex
refs
int
}
func
newScopedKeyLockRegistry
()
*
scopedKeyLockRegistry
{
return
&
scopedKeyLockRegistry
{
locks
:
make
(
map
[
string
]
*
scopedKeyLockEntry
),
}
}
func
(
r
*
scopedKeyLockRegistry
)
lock
(
keys
...
string
)
func
()
{
normalized
:=
normalizeLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
{
return
func
()
{}
}
entries
:=
make
([]
*
scopedKeyLockEntry
,
0
,
len
(
normalized
))
r
.
mu
.
Lock
()
for
_
,
key
:=
range
normalized
{
entry
:=
r
.
locks
[
key
]
if
entry
==
nil
{
entry
=
&
scopedKeyLockEntry
{}
r
.
locks
[
key
]
=
entry
}
entry
.
refs
++
entries
=
append
(
entries
,
entry
)
}
r
.
mu
.
Unlock
()
for
_
,
entry
:=
range
entries
{
entry
.
mu
.
Lock
()
}
return
func
()
{
for
i
:=
len
(
entries
)
-
1
;
i
>=
0
;
i
--
{
entries
[
i
]
.
mu
.
Unlock
()
}
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
for
idx
,
key
:=
range
normalized
{
entry
:=
entries
[
idx
]
entry
.
refs
--
if
entry
.
refs
==
0
{
delete
(
r
.
locks
,
key
)
}
}
}
}
func
normalizeLockKeys
(
keys
...
string
)
[]
string
{
if
len
(
keys
)
==
0
{
return
nil
}
deduped
:=
make
(
map
[
string
]
struct
{},
len
(
keys
))
for
_
,
key
:=
range
keys
{
trimmed
:=
strings
.
TrimSpace
(
key
)
if
trimmed
==
""
{
continue
}
deduped
[
trimmed
]
=
struct
{}{}
}
if
len
(
deduped
)
==
0
{
return
nil
}
normalized
:=
make
([]
string
,
0
,
len
(
deduped
))
for
key
:=
range
deduped
{
normalized
=
append
(
normalized
,
key
)
}
sort
.
Strings
(
normalized
)
return
normalized
}
func
advisoryLockHash
(
key
string
)
int64
{
hasher
:=
fnv
.
New64a
()
_
,
_
=
hasher
.
Write
([]
byte
(
key
))
return
int64
(
hasher
.
Sum64
())
}
func
lockRepositoryScopedKeys
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
exec
sqlQueryExecutor
,
keys
...
string
)
(
func
(),
error
)
{
release
:=
repositoryScopedKeyLocks
.
lock
(
keys
...
)
normalized
:=
normalizeLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
||
client
==
nil
||
exec
==
nil
||
client
.
Driver
()
.
Dialect
()
!=
dialect
.
Postgres
{
return
release
,
nil
}
for
_
,
key
:=
range
normalized
{
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT pg_advisory_xact_lock($1)"
,
advisoryLockHash
(
key
))
if
err
!=
nil
{
release
()
return
nil
,
err
}
_
=
rows
.
Close
()
}
return
release
,
nil
}
func
(
r
*
userRepository
)
WithUserProfileIdentityTx
(
ctx
context
.
Context
,
fn
func
(
txCtx
context
.
Context
)
error
)
error
{
if
dbent
.
TxFromContext
(
ctx
)
!=
nil
{
return
fn
(
ctx
)
...
...
@@ -329,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return
err
}
}
else
{
targetProviderKey
:=
canonicalizeCompatibleIdentityProviderKey
(
canonical
.
ProviderType
,
identity
.
ProviderKey
,
canonical
.
ProviderKey
)
update
:=
client
.
AuthIdentity
.
UpdateOneID
(
identity
.
ID
)
if
targetProviderKey
!=
""
&&
!
strings
.
EqualFold
(
targetProviderKey
,
identity
.
ProviderKey
)
{
update
=
update
.
SetProviderKey
(
targetProviderKey
)
}
if
input
.
Metadata
!=
nil
{
update
=
update
.
SetMetadata
(
copyMetadata
(
input
.
Metadata
))
}
...
...
@@ -378,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return
err
}
}
else
{
targetProviderKey
:=
canonicalizeCompatibleIdentityProviderKey
(
input
.
Channel
.
ProviderType
,
channel
.
ProviderKey
,
input
.
Channel
.
ProviderKey
)
update
:=
client
.
AuthIdentityChannel
.
UpdateOneID
(
channel
.
ID
)
.
SetIdentityID
(
identity
.
ID
)
if
targetProviderKey
!=
""
&&
!
strings
.
EqualFold
(
targetProviderKey
,
channel
.
ProviderKey
)
{
update
=
update
.
SetProviderKey
(
targetProviderKey
)
}
if
input
.
ChannelMetadata
!=
nil
{
update
=
update
.
SetMetadata
(
copyMetadata
(
input
.
ChannelMetadata
))
}
...
...
@@ -418,13 +537,52 @@ func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
return
keys
}
func
canonicalizeCompatibleIdentityProviderKey
(
providerType
,
existingKey
,
requestedKey
string
)
string
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
existingKey
=
strings
.
TrimSpace
(
existingKey
)
requestedKey
=
strings
.
TrimSpace
(
requestedKey
)
if
providerType
!=
"wechat"
{
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
if
strings
.
EqualFold
(
existingKey
,
"wechat"
)
||
strings
.
EqualFold
(
existingKey
,
"wechat-main"
)
||
strings
.
EqualFold
(
requestedKey
,
"wechat-main"
)
{
return
"wechat-main"
}
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
func
compatibleIdentityProviderKeyRank
(
providerType
,
providerKey
string
)
int
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerType
!=
"wechat"
{
return
0
}
switch
{
case
strings
.
EqualFold
(
providerKey
,
"wechat-main"
)
:
return
0
case
strings
.
EqualFold
(
providerKey
,
"wechat"
)
:
return
2
default
:
return
1
}
}
func
selectOwnedCompatibleIdentity
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
*
dbent
.
AuthIdentity
{
var
selected
*
dbent
.
AuthIdentity
for
_
,
record
:=
range
records
{
if
record
.
UserID
==
userID
{
return
record
if
record
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
compatibleIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
compatibleIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
nil
return
selected
}
func
hasCompatibleIdentityConflict
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
bool
{
...
...
@@ -437,12 +595,16 @@ func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64)
}
func
selectOwnedCompatibleChannel
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
*
dbent
.
AuthIdentityChannel
{
var
selected
*
dbent
.
AuthIdentityChannel
for
_
,
record
:=
range
records
{
if
record
.
Edges
.
Identity
!=
nil
&&
record
.
Edges
.
Identity
.
UserID
==
userID
{
return
record
if
record
.
Edges
.
Identity
==
nil
||
record
.
Edges
.
Identity
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
compatibleIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
compatibleIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
nil
return
selected
}
func
hasCompatibleChannelConflict
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
bool
{
...
...
@@ -479,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
}
func
(
r
*
userRepository
)
UpsertIdentityAdoptionDecision
(
ctx
context
.
Context
,
input
IdentityAdoptionDecisionInput
)
(
*
dbent
.
IdentityAdoptionDecision
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
if
_
,
err
:=
client
.
IdentityAdoptionDecision
.
Update
()
.
Where
(
identityadoptiondecision
.
IdentityIDEQ
(
*
input
.
IdentityID
),
dbpredicate
.
IdentityAdoptionDecision
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
NEQ
(
col
,
input
.
PendingAuthSessionID
),
))
}),
)
.
ClearIdentityID
()
.
Save
(
ctx
);
err
!=
nil
{
return
nil
,
err
var
result
*
dbent
.
IdentityAdoptionDecision
err
:=
r
.
WithUserProfileIdentityTx
(
ctx
,
func
(
txCtx
context
.
Context
)
error
{
client
:=
clientFromContext
(
txCtx
,
r
.
client
)
releaseLocks
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
client
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
identityAdoptionDecisionLockKeys
(
input
.
PendingAuthSessionID
,
input
.
IdentityID
)
...
,
)
if
err
!=
nil
{
return
err
}
defer
releaseLocks
()
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
if
_
,
err
:=
client
.
IdentityAdoptionDecision
.
Update
()
.
Where
(
identityadoptiondecision
.
IdentityIDEQ
(
*
input
.
IdentityID
),
dbpredicate
.
IdentityAdoptionDecision
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
NEQ
(
col
,
input
.
PendingAuthSessionID
),
))
}),
)
.
ClearIdentityID
()
.
Save
(
txCtx
);
err
!=
nil
{
return
err
}
}
}
current
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
input
.
PendingAuthSessionID
))
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
now
:=
time
.
Now
()
.
UTC
()
if
current
==
nil
{
create
:=
client
.
IdentityAdoptionDecision
.
Create
()
.
SetPendingAuthSessionID
(
input
.
PendingAuthSessionID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
.
SetDecidedAt
(
now
)
if
input
.
IdentityID
!=
nil
{
SetDecidedAt
(
time
.
Now
()
.
UTC
()
)
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
create
=
create
.
SetIdentityID
(
*
input
.
IdentityID
)
}
return
create
.
Save
(
ctx
)
decisionID
,
err
:=
create
.
OnConflictColumns
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
.
UpdateNewValues
()
.
ID
(
txCtx
)
if
err
!=
nil
{
return
err
}
result
,
err
=
client
.
IdentityAdoptionDecision
.
Get
(
txCtx
,
decisionID
)
return
err
})
if
err
!=
nil
{
return
nil
,
err
}
return
result
,
nil
}
update
:=
client
.
IdentityAdoptionDecision
.
UpdateOneID
(
current
.
ID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
if
input
.
IdentityID
!=
nil
{
update
=
update
.
SetIdentityID
(
*
input
.
IdentityID
)
func
identityAdoptionDecisionLockKeys
(
pendingAuthSessionID
int64
,
identityID
*
int64
)
[]
string
{
keys
:=
[]
string
{
fmt
.
Sprintf
(
"identity-adoption:pending:%d"
,
pendingAuthSessionID
)}
if
identityID
!=
nil
&&
*
identityID
>
0
{
keys
=
append
(
keys
,
fmt
.
Sprintf
(
"identity-adoption:identity:%d"
,
*
identityID
))
}
return
update
.
Save
(
ctx
)
return
keys
}
func
(
r
*
userRepository
)
GetIdentityAdoptionDecisionByPendingAuthSessionID
(
ctx
context
.
Context
,
pendingAuthSessionID
int64
)
(
*
dbent
.
IdentityAdoptionDecision
,
error
)
{
...
...
backend/internal/repository/user_profile_identity_repo_unit_test.go
0 → 100644
View file @
36aed359
package
repository
import
(
"context"
"sync"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
user
:=
&
service
.
User
{
Email
:
"wechat-legacy@example.com"
,
Username
:
"wechat-legacy"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
user
))
legacyIdentity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetProviderSubject
(
"union-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"source"
:
"legacy-alias"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
legacyChannel
,
err
:=
client
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
legacyIdentity
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetChannel
(
"oa"
)
.
SetChannelAppID
(
"wx-app-legacy"
)
.
SetChannelSubject
(
"openid-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"scene"
:
"legacy-alias"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
bound
,
err
:=
repo
.
BindAuthIdentityToUser
(
ctx
,
BindAuthIdentityInput
{
UserID
:
user
.
ID
,
Canonical
:
AuthIdentityKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
ProviderSubject
:
"union-legacy-123"
,
},
Channel
:
&
AuthIdentityChannelKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
Channel
:
"oa"
,
ChannelAppID
:
"wx-app-legacy"
,
ChannelSubject
:
"openid-legacy-123"
,
},
Metadata
:
map
[
string
]
any
{
"source"
:
"canonical-bind"
},
ChannelMetadata
:
map
[
string
]
any
{
"scene"
:
"canonical-bind"
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
bound
)
require
.
NotNil
(
t
,
bound
.
Identity
)
require
.
NotNil
(
t
,
bound
.
Channel
)
require
.
Equal
(
t
,
legacyIdentity
.
ID
,
bound
.
Identity
.
ID
)
require
.
Equal
(
t
,
legacyChannel
.
ID
,
bound
.
Channel
.
ID
)
require
.
Equal
(
t
,
"wechat-main"
,
bound
.
Identity
.
ProviderKey
)
require
.
Equal
(
t
,
"wechat-main"
,
bound
.
Channel
.
ProviderKey
)
reloadedIdentity
,
err
:=
client
.
AuthIdentity
.
Get
(
ctx
,
legacyIdentity
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
reloadedIdentity
.
ProviderKey
)
require
.
Equal
(
t
,
"canonical-bind"
,
reloadedIdentity
.
Metadata
[
"source"
])
reloadedChannel
,
err
:=
client
.
AuthIdentityChannel
.
Get
(
ctx
,
legacyChannel
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
reloadedChannel
.
ProviderKey
)
require
.
Equal
(
t
,
"canonical-bind"
,
reloadedChannel
.
Metadata
[
"scene"
])
identityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
UserIDEQ
(
user
.
ID
),
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderSubjectEQ
(
"union-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
identityCount
)
channelCount
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ChannelEQ
(
"oa"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-app-legacy"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
channelCount
)
}
func
TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
user
:=
&
service
.
User
{
Email
:
"repo-adoption@example.com"
,
Username
:
"repo-adoption"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
user
))
identity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"union-repo-adoption"
)
.
SetMetadata
(
map
[
string
]
any
{})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"pending-repo-adoption"
)
.
SetIntent
(
"bind_current_user"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"union-repo-adoption"
)
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
15
*
time
.
Minute
))
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"provider_subject"
:
"union-repo-adoption"
})
.
SetLocalFlowState
(
map
[
string
]
any
{
"step"
:
"pending"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
firstCreateStarted
:=
make
(
chan
struct
{})
releaseFirstCreate
:=
make
(
chan
struct
{})
var
firstCreate
sync
.
Once
client
.
IdentityAdoptionDecision
.
Use
(
func
(
next
dbent
.
Mutator
)
dbent
.
Mutator
{
return
dbent
.
MutateFunc
(
func
(
ctx
context
.
Context
,
m
dbent
.
Mutation
)
(
dbent
.
Value
,
error
)
{
blocked
:=
false
if
m
.
Op
()
.
Is
(
dbent
.
OpCreate
)
{
firstCreate
.
Do
(
func
()
{
blocked
=
true
close
(
firstCreateStarted
)
})
}
if
blocked
{
<-
releaseFirstCreate
}
return
next
.
Mutate
(
ctx
,
m
)
})
})
type
adoptionResult
struct
{
decision
*
dbent
.
IdentityAdoptionDecision
err
error
}
input
:=
IdentityAdoptionDecisionInput
{
PendingAuthSessionID
:
session
.
ID
,
IdentityID
:
&
identity
.
ID
,
AdoptDisplayName
:
true
,
AdoptAvatar
:
true
,
}
results
:=
make
(
chan
adoptionResult
,
2
)
go
func
()
{
decision
,
err
:=
repo
.
UpsertIdentityAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
<-
firstCreateStarted
go
func
()
{
decision
,
err
:=
repo
.
UpsertIdentityAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
time
.
Sleep
(
100
*
time
.
Millisecond
)
close
(
releaseFirstCreate
)
first
:=
<-
results
second
:=
<-
results
require
.
NoError
(
t
,
first
.
err
)
require
.
NoError
(
t
,
second
.
err
)
require
.
NotNil
(
t
,
first
.
decision
)
require
.
NotNil
(
t
,
second
.
decision
)
require
.
Equal
(
t
,
first
.
decision
.
ID
,
second
.
decision
.
ID
)
count
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
loaded
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
loaded
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
loaded
.
IdentityID
)
require
.
True
(
t
,
loaded
.
AdoptDisplayName
)
require
.
True
(
t
,
loaded
.
AdoptAvatar
)
}
backend/internal/repository/user_repo.go
View file @
36aed359
...
...
@@ -43,9 +43,6 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if
userIn
==
nil
{
return
nil
}
if
err
:=
r
.
ensureNormalizedEmailAvailable
(
ctx
,
0
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
...
...
@@ -55,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var
txClient
*
dbent
.
Client
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txClient
=
tx
.
Client
()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
else
{
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if
existingTx
:=
dbent
.
TxFromContext
(
ctx
);
existingTx
!=
nil
{
...
...
@@ -67,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
}
releaseEmailLock
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
txClient
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
normalizedEmailUniquenessLockKey
(
userIn
.
Email
),
)
if
err
!=
nil
{
return
err
}
defer
releaseEmailLock
()
if
err
:=
ensureNormalizedEmailAvailableWithClient
(
txCtx
,
txClient
,
0
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
created
,
err
:=
txClient
.
User
.
Create
()
.
SetEmail
(
userIn
.
Email
)
.
SetUsername
(
userIn
.
Username
)
.
...
...
@@ -79,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
Save
(
c
tx
)
Save
(
txC
tx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
}
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
c
tx
,
txClient
,
created
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
txC
tx
,
txClient
,
created
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
return
err
}
if
err
:=
ensureEmailAuthIdentityWithClient
(
c
tx
,
txClient
,
created
.
ID
,
created
.
Email
,
"user_repo_create"
);
err
!=
nil
{
if
err
:=
ensureEmailAuthIdentityWithClient
(
txC
tx
,
txClient
,
created
.
ID
,
created
.
Email
,
"user_repo_create"
);
err
!=
nil
{
return
err
}
...
...
@@ -149,9 +163,6 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if
userIn
==
nil
{
return
nil
}
if
err
:=
r
.
ensureNormalizedEmailAvailable
(
ctx
,
userIn
.
ID
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
tx
,
err
:=
r
.
client
.
Tx
(
ctx
)
...
...
@@ -160,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var
txClient
*
dbent
.
Client
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txClient
=
tx
.
Client
()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
else
{
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if
existingTx
:=
dbent
.
TxFromContext
(
ctx
);
existingTx
!=
nil
{
...
...
@@ -171,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient
=
r
.
client
}
}
existing
,
err
:=
clientFromContext
(
ctx
,
txClient
)
.
User
.
Get
(
ctx
,
userIn
.
ID
)
releaseEmailLock
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
txClient
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
normalizedEmailUniquenessLockKey
(
userIn
.
Email
),
)
if
err
!=
nil
{
return
err
}
defer
releaseEmailLock
()
if
err
:=
ensureNormalizedEmailAvailableWithClient
(
txCtx
,
txClient
,
userIn
.
ID
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
existing
,
err
:=
clientFromContext
(
txCtx
,
txClient
)
.
User
.
Get
(
txCtx
,
userIn
.
ID
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
...
...
@@ -203,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if
userIn
.
BalanceNotifyThreshold
==
nil
{
updateOp
=
updateOp
.
ClearBalanceNotifyThreshold
()
}
updated
,
err
:=
updateOp
.
Save
(
c
tx
)
updated
,
err
:=
updateOp
.
Save
(
txC
tx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
service
.
ErrEmailExists
)
}
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
c
tx
,
txClient
,
updated
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
txC
tx
,
txClient
,
updated
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
return
err
}
if
err
:=
replaceEmailAuthIdentityWithClient
(
c
tx
,
txClient
,
updated
.
ID
,
oldEmail
,
updated
.
Email
,
"user_repo_update"
);
err
!=
nil
{
if
err
:=
replaceEmailAuthIdentityWithClient
(
txC
tx
,
txClient
,
updated
.
ID
,
oldEmail
,
updated
.
Email
,
"user_repo_update"
);
err
!=
nil
{
return
err
}
...
...
@@ -711,7 +740,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
}
func
(
r
*
userRepository
)
ensureNormalizedEmailAvailable
(
ctx
context
.
Context
,
userID
int64
,
email
string
)
error
{
matches
,
err
:=
r
.
client
.
User
.
Query
()
.
return
ensureNormalizedEmailAvailableWithClient
(
ctx
,
clientFromContext
(
ctx
,
r
.
client
),
userID
,
email
)
}
func
ensureNormalizedEmailAvailableWithClient
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
userID
int64
,
email
string
)
error
{
client
=
clientFromContext
(
ctx
,
client
)
if
client
==
nil
{
return
nil
}
matches
,
err
:=
client
.
User
.
Query
()
.
Where
(
userEmailLookupPredicate
(
email
))
.
All
(
ctx
)
if
err
!=
nil
{
...
...
@@ -726,7 +764,7 @@ func (r *userRepository) ensureNormalizedEmailAvailable(ctx context.Context, use
}
func
userEmailLookupPredicate
(
email
string
)
predicate
.
User
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpac
e
(
email
)
)
normalized
:=
normalizeEmailLookupValu
e
(
email
)
if
normalized
==
""
{
return
dbuser
.
EmailEQ
(
email
)
}
...
...
@@ -740,6 +778,18 @@ func userEmailLookupPredicate(email string) predicate.User {
})
}
func
normalizeEmailLookupValue
(
email
string
)
string
{
return
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
}
func
normalizedEmailUniquenessLockKey
(
email
string
)
string
{
normalized
:=
normalizeEmailLookupValue
(
email
)
if
normalized
==
""
{
return
""
}
return
"users:normalized-email:"
+
normalized
}
func
(
r
*
userRepository
)
AddGroupToAllowedGroups
(
ctx
context
.
Context
,
userID
int64
,
groupID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
err
:=
client
.
UserAllowedGroup
.
Create
()
.
...
...
@@ -874,11 +924,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
}
func
userSignupSourceOrDefault
(
signupSource
string
)
string
{
signupSource
=
strings
.
TrimSpace
(
signupSource
)
if
signupSource
==
""
{
switch
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
{
case
""
,
"email"
:
return
"email"
case
"linuxdo"
,
"wechat"
,
"oidc"
:
return
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
default
:
return
"email"
}
return
signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment