Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
5adefb46
Commit
5adefb46
authored
Apr 20, 2026
by
IanShaw027
Browse files
fix: finalize oauth identity bindings
parent
bdcd3d87
Changes
4
Hide whitespace changes
Inline
Side-by-side
backend/internal/handler/auth_oauth_pending_flow.go
View file @
5adefb46
...
@@ -10,6 +10,7 @@ import (
...
@@ -10,6 +10,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
dbuser
"github.com/Wei-Shaw/sub2api/ent/user"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
...
@@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any {
...
@@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any {
return
cloned
return
cloned
}
}
func
mergeOAuthMetadata
(
base
map
[
string
]
any
,
overlay
map
[
string
]
any
)
map
[
string
]
any
{
merged
:=
cloneOAuthMetadata
(
base
)
for
key
,
value
:=
range
overlay
{
merged
[
key
]
=
value
}
return
merged
}
func
normalizeAdoptedOAuthDisplayName
(
value
string
)
string
{
func
normalizeAdoptedOAuthDisplayName
(
value
string
)
string
{
value
=
strings
.
TrimSpace
(
value
)
value
=
strings
.
TrimSpace
(
value
)
if
len
([]
rune
(
value
))
>
100
{
if
len
([]
rune
(
value
))
>
100
{
...
@@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
...
@@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
}
}
func
ensurePendingOAuthIdentityForUser
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
,
session
*
dbent
.
PendingAuthSession
,
userID
int64
)
(
*
dbent
.
AuthIdentity
,
error
)
{
func
ensurePendingOAuthIdentityForUser
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
,
session
*
dbent
.
PendingAuthSession
,
userID
int64
)
(
*
dbent
.
AuthIdentity
,
error
)
{
if
session
!=
nil
&&
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
ProviderType
),
"wechat"
)
{
return
ensurePendingWeChatOAuthIdentityForUser
(
ctx
,
tx
,
session
,
userID
)
}
client
:=
tx
.
Client
()
client
:=
tx
.
Client
()
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
Where
(
...
@@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
...
@@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio
return
create
.
Save
(
ctx
)
return
create
.
Save
(
ctx
)
}
}
func
ensurePendingWeChatOAuthIdentityForUser
(
ctx
context
.
Context
,
tx
*
dbent
.
Tx
,
session
*
dbent
.
PendingAuthSession
,
userID
int64
)
(
*
dbent
.
AuthIdentity
,
error
)
{
client
:=
tx
.
Client
()
providerType
:=
strings
.
TrimSpace
(
session
.
ProviderType
)
providerKey
:=
strings
.
TrimSpace
(
session
.
ProviderKey
)
providerSubject
:=
strings
.
TrimSpace
(
session
.
ProviderSubject
)
channel
:=
strings
.
TrimSpace
(
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"channel"
))
channelAppID
:=
strings
.
TrimSpace
(
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"channel_app_id"
))
channelSubject
:=
strings
.
TrimSpace
(
pendingSessionStringValue
(
session
.
UpstreamIdentityClaims
,
"channel_subject"
))
metadata
:=
cloneOAuthMetadata
(
session
.
UpstreamIdentityClaims
)
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
providerType
),
authidentity
.
ProviderKeyEQ
(
providerKey
),
authidentity
.
ProviderSubjectEQ
(
providerSubject
),
)
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
if
identity
!=
nil
&&
identity
.
UserID
!=
userID
{
return
nil
,
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
var
legacyOpenIDIdentity
*
dbent
.
AuthIdentity
if
channelSubject
!=
""
&&
channelSubject
!=
providerSubject
{
legacyOpenIDIdentity
,
err
=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
providerType
),
authidentity
.
ProviderKeyEQ
(
providerKey
),
authidentity
.
ProviderSubjectEQ
(
channelSubject
),
)
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
if
legacyOpenIDIdentity
!=
nil
&&
legacyOpenIDIdentity
.
UserID
!=
userID
{
return
nil
,
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
}
switch
{
case
identity
!=
nil
:
update
:=
client
.
AuthIdentity
.
UpdateOneID
(
identity
.
ID
)
.
SetMetadata
(
mergeOAuthMetadata
(
identity
.
Metadata
,
metadata
))
if
issuer
:=
oauthIdentityIssuer
(
session
);
issuer
!=
nil
{
update
=
update
.
SetIssuer
(
strings
.
TrimSpace
(
*
issuer
))
}
identity
,
err
=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
case
legacyOpenIDIdentity
!=
nil
:
update
:=
client
.
AuthIdentity
.
UpdateOneID
(
legacyOpenIDIdentity
.
ID
)
.
SetProviderSubject
(
providerSubject
)
.
SetMetadata
(
mergeOAuthMetadata
(
legacyOpenIDIdentity
.
Metadata
,
metadata
))
if
issuer
:=
oauthIdentityIssuer
(
session
);
issuer
!=
nil
{
update
=
update
.
SetIssuer
(
strings
.
TrimSpace
(
*
issuer
))
}
identity
,
err
=
update
.
Save
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
default
:
create
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
userID
)
.
SetProviderType
(
providerType
)
.
SetProviderKey
(
providerKey
)
.
SetProviderSubject
(
providerSubject
)
.
SetMetadata
(
metadata
)
if
issuer
:=
oauthIdentityIssuer
(
session
);
issuer
!=
nil
{
create
=
create
.
SetIssuer
(
strings
.
TrimSpace
(
*
issuer
))
}
identity
,
err
=
create
.
Save
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
}
if
channel
==
""
||
channelAppID
==
""
||
channelSubject
==
""
{
return
identity
,
nil
}
channelRecord
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
providerType
),
authidentitychannel
.
ProviderKeyEQ
(
providerKey
),
authidentitychannel
.
ChannelEQ
(
channel
),
authidentitychannel
.
ChannelAppIDEQ
(
channelAppID
),
authidentitychannel
.
ChannelSubjectEQ
(
channelSubject
),
)
.
WithIdentity
()
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
if
channelRecord
!=
nil
&&
channelRecord
.
Edges
.
Identity
!=
nil
&&
channelRecord
.
Edges
.
Identity
.
UserID
!=
userID
{
return
nil
,
infraerrors
.
Conflict
(
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT"
,
"auth identity channel already belongs to another user"
)
}
channelMetadata
:=
mergeOAuthMetadata
(
channelRecordMetadata
(
channelRecord
),
metadata
)
if
channelRecord
==
nil
{
if
_
,
err
:=
client
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
identity
.
ID
)
.
SetProviderType
(
providerType
)
.
SetProviderKey
(
providerKey
)
.
SetChannel
(
channel
)
.
SetChannelAppID
(
channelAppID
)
.
SetChannelSubject
(
channelSubject
)
.
SetMetadata
(
channelMetadata
)
.
Save
(
ctx
);
err
!=
nil
{
return
nil
,
err
}
return
identity
,
nil
}
_
,
err
=
client
.
AuthIdentityChannel
.
UpdateOneID
(
channelRecord
.
ID
)
.
SetIdentityID
(
identity
.
ID
)
.
SetMetadata
(
channelMetadata
)
.
Save
(
ctx
)
if
err
!=
nil
{
return
nil
,
err
}
return
identity
,
nil
}
func
channelRecordMetadata
(
channel
*
dbent
.
AuthIdentityChannel
)
map
[
string
]
any
{
if
channel
==
nil
{
return
map
[
string
]
any
{}
}
return
cloneOAuthMetadata
(
channel
.
Metadata
)
}
func
shouldBindPendingOAuthIdentity
(
session
*
dbent
.
PendingAuthSession
,
decision
*
dbent
.
IdentityAdoptionDecision
)
bool
{
func
shouldBindPendingOAuthIdentity
(
session
*
dbent
.
PendingAuthSession
,
decision
*
dbent
.
IdentityAdoptionDecision
)
bool
{
if
session
==
nil
||
decision
==
nil
{
if
session
==
nil
||
decision
==
nil
{
return
false
return
false
}
}
if
strings
.
EqualFold
(
strings
.
TrimSpace
(
session
.
Intent
),
"bind_current_user"
)
{
switch
strings
.
ToLower
(
strings
.
TrimSpace
(
session
.
Intent
))
{
case
"bind_current_user"
,
"login"
,
"adopt_existing_user_by_email"
:
return
true
return
true
default
:
return
decision
.
AdoptDisplayName
||
decision
.
AdoptAvatar
}
}
return
decision
.
AdoptDisplayName
||
decision
.
AdoptAvatar
}
}
func
applyPendingOAuthBinding
(
func
applyPendingOAuthBinding
(
...
...
backend/internal/handler/auth_oauth_pending_flow_test.go
View file @
5adefb46
...
@@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi
...
@@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
require
.
Nil
(
t
,
storedSession
.
ConsumedAt
)
}
}
func
TestExchangePendingOAuthCompletionLoginFalseFalse
DoesNot
BindIdentity
(
t
*
testing
.
T
)
{
func
TestExchangePendingOAuthCompletionLoginFalseFalseBind
s
Identity
WithoutAdoption
(
t
*
testing
.
T
)
{
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
handler
,
client
:=
newOAuthPendingFlowTestHandler
(
t
,
false
)
ctx
:=
context
.
Background
()
ctx
:=
context
.
Background
()
...
@@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes
...
@@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
require
.
Equal
(
t
,
http
.
StatusOK
,
recorder
.
Code
)
identity
Count
,
err
:=
client
.
AuthIdentity
.
Query
()
.
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
Where
(
authidentity
.
ProviderTypeEQ
(
"linuxdo"
),
authidentity
.
ProviderTypeEQ
(
"linuxdo"
),
authidentity
.
ProviderKeyEQ
(
"linuxdo"
),
authidentity
.
ProviderKeyEQ
(
"linuxdo"
),
authidentity
.
ProviderSubjectEQ
(
"login-false-123"
),
authidentity
.
ProviderSubjectEQ
(
"login-false-123"
),
)
.
)
.
Count
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
identity
Count
)
require
.
Equal
(
t
,
userEntity
.
ID
,
identity
.
UserID
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
err
)
require
.
Nil
(
t
,
decision
.
IdentityID
)
require
.
NotNil
(
t
,
decision
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
decision
.
IdentityID
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptDisplayName
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
require
.
False
(
t
,
decision
.
AdoptAvatar
)
...
...
backend/internal/handler/auth_wechat_oauth.go
View file @
5adefb46
...
@@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
...
@@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) {
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
return
}
}
if
existingIdentityUser
==
nil
{
existingIdentityUser
,
err
=
h
.
findWeChatUserByLegacyOpenID
(
c
.
Request
.
Context
(),
identityRef
,
cfg
,
openid
)
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
}
if
existingIdentityUser
!=
nil
{
if
existingIdentityUser
!=
nil
{
if
err
:=
h
.
ensureWeChatRuntimeIdentityBinding
(
c
.
Request
.
Context
(),
existingIdentityUser
.
ID
,
identityRef
,
upstreamClaims
);
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"session_error"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
return
}
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
tokenPair
,
user
,
err
:=
h
.
authService
.
LoginOrRegisterOAuthWithTokenPair
(
c
.
Request
.
Context
(),
existingIdentityUser
.
Email
,
username
,
""
)
if
err
!=
nil
{
if
err
!=
nil
{
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
redirectOAuthError
(
c
,
frontendCallback
,
"login_failed"
,
infraerrors
.
Reason
(
err
),
infraerrors
.
Message
(
err
))
...
@@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
...
@@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership(
return
nil
return
nil
}
}
func
(
h
*
AuthHandler
)
findWeChatUserByLegacyOpenID
(
ctx
context
.
Context
,
identity
service
.
PendingAuthIdentityKey
,
cfg
wechatOAuthConfig
,
openid
string
,
)
(
*
dbent
.
User
,
error
)
{
client
:=
h
.
entClient
()
if
client
==
nil
{
return
nil
,
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
openid
=
strings
.
TrimSpace
(
openid
)
channel
:=
strings
.
TrimSpace
(
cfg
.
mode
)
channelAppID
:=
strings
.
TrimSpace
(
cfg
.
appID
)
if
openid
!=
""
&&
channel
!=
""
&&
channelAppID
!=
""
{
record
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
identity
.
ProviderType
)),
authidentitychannel
.
ProviderKeyEQ
(
strings
.
TrimSpace
(
identity
.
ProviderKey
)),
authidentitychannel
.
ChannelEQ
(
channel
),
authidentitychannel
.
ChannelAppIDEQ
(
channelAppID
),
authidentitychannel
.
ChannelSubjectEQ
(
openid
),
)
.
WithIdentity
(
func
(
q
*
dbent
.
AuthIdentityQuery
)
{
q
.
WithUser
()
})
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED"
,
"failed to inspect auth identity channel ownership"
)
.
WithCause
(
err
)
}
if
record
!=
nil
&&
record
.
Edges
.
Identity
!=
nil
&&
record
.
Edges
.
Identity
.
Edges
.
User
!=
nil
{
return
record
.
Edges
.
Identity
.
Edges
.
User
,
nil
}
}
if
openid
==
""
{
return
nil
,
nil
}
record
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
identity
.
ProviderType
)),
authidentity
.
ProviderKeyEQ
(
strings
.
TrimSpace
(
identity
.
ProviderKey
)),
authidentity
.
ProviderSubjectEQ
(
openid
),
)
.
WithUser
()
.
Only
(
ctx
)
if
err
!=
nil
{
if
dbent
.
IsNotFound
(
err
)
{
return
nil
,
nil
}
return
nil
,
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
return
record
.
Edges
.
User
,
nil
}
func
(
h
*
AuthHandler
)
ensureWeChatRuntimeIdentityBinding
(
ctx
context
.
Context
,
userID
int64
,
identity
service
.
PendingAuthIdentityKey
,
upstreamClaims
map
[
string
]
any
,
)
error
{
client
:=
h
.
entClient
()
if
client
==
nil
{
return
infraerrors
.
ServiceUnavailable
(
"PENDING_AUTH_NOT_READY"
,
"pending auth service is not ready"
)
}
tx
,
err
:=
client
.
Tx
(
ctx
)
if
err
!=
nil
{
return
infraerrors
.
InternalServer
(
"AUTH_IDENTITY_BIND_FAILED"
,
"failed to begin wechat identity repair transaction"
)
.
WithCause
(
err
)
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
_
,
err
=
ensurePendingOAuthIdentityForUser
(
dbent
.
NewTxContext
(
ctx
,
tx
),
tx
,
&
dbent
.
PendingAuthSession
{
ProviderType
:
strings
.
TrimSpace
(
identity
.
ProviderType
),
ProviderKey
:
strings
.
TrimSpace
(
identity
.
ProviderKey
),
ProviderSubject
:
strings
.
TrimSpace
(
identity
.
ProviderSubject
),
UpstreamIdentityClaims
:
cloneOAuthMetadata
(
upstreamClaims
),
},
userID
)
if
err
!=
nil
{
return
err
}
return
tx
.
Commit
()
}
func
(
h
*
AuthHandler
)
getWeChatOAuthConfig
(
ctx
context
.
Context
,
rawMode
string
,
c
*
gin
.
Context
)
(
wechatOAuthConfig
,
error
)
{
func
(
h
*
AuthHandler
)
getWeChatOAuthConfig
(
ctx
context
.
Context
,
rawMode
string
,
c
*
gin
.
Context
)
(
wechatOAuthConfig
,
error
)
{
mode
,
err
:=
resolveWeChatOAuthMode
(
rawMode
,
c
)
mode
,
err
:=
resolveWeChatOAuthMode
(
rawMode
,
c
)
if
err
!=
nil
{
if
err
!=
nil
{
...
...
backend/internal/handler/auth_wechat_oauth_test.go
View file @
5adefb46
...
@@ -15,6 +15,7 @@ import (
...
@@ -15,6 +15,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
...
@@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
...
@@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require
.
Equal
(
t
,
"WeChat Display"
,
identity
.
Metadata
[
"display_name"
])
require
.
Equal
(
t
,
"WeChat Display"
,
identity
.
Metadata
[
"display_name"
])
require
.
Equal
(
t
,
"https://cdn.example/wechat.png"
,
identity
.
Metadata
[
"avatar_url"
])
require
.
Equal
(
t
,
"https://cdn.example/wechat.png"
,
identity
.
Metadata
[
"avatar_url"
])
channel
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ProviderKeyEQ
(
"wechat-main"
),
authidentitychannel
.
ChannelEQ
(
"open"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open-app"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-123"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
identity
.
ID
,
channel
.
IdentityID
)
require
.
Equal
(
t
,
"union-456"
,
channel
.
Metadata
[
"unionid"
])
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
pendingSession
.
ID
))
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
pendingSession
.
ID
))
.
Only
(
ctx
)
Only
(
ctx
)
...
@@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
...
@@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
require
.
NotNil
(
t
,
consumed
.
ConsumedAt
)
}
}
func
TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity
(
t
*
testing
.
T
)
{
t
.
Setenv
(
"WECHAT_OAUTH_OPEN_APP_ID"
,
"wx-open-app"
)
t
.
Setenv
(
"WECHAT_OAUTH_OPEN_APP_SECRET"
,
"wx-open-secret"
)
t
.
Setenv
(
"WECHAT_OAUTH_FRONTEND_REDIRECT_URL"
,
"/auth/wechat/callback"
)
originalAccessTokenURL
:=
wechatOAuthAccessTokenURL
originalUserInfoURL
:=
wechatOAuthUserInfoURL
t
.
Cleanup
(
func
()
{
wechatOAuthAccessTokenURL
=
originalAccessTokenURL
wechatOAuthUserInfoURL
=
originalUserInfoURL
})
upstream
:=
httptest
.
NewServer
(
http
.
HandlerFunc
(
func
(
w
http
.
ResponseWriter
,
r
*
http
.
Request
)
{
switch
{
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/oauth2/access_token"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`
))
case
strings
.
Contains
(
r
.
URL
.
Path
,
"/sns/userinfo"
)
:
w
.
Header
()
.
Set
(
"Content-Type"
,
"application/json"
)
_
,
_
=
w
.
Write
([]
byte
(
`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`
))
default
:
http
.
NotFound
(
w
,
r
)
}
}))
defer
upstream
.
Close
()
wechatOAuthAccessTokenURL
=
upstream
.
URL
+
"/sns/oauth2/access_token"
wechatOAuthUserInfoURL
=
upstream
.
URL
+
"/sns/userinfo"
handler
,
client
:=
newWeChatOAuthTestHandler
(
t
,
false
)
defer
client
.
Close
()
ctx
:=
context
.
Background
()
legacyUser
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"legacy@example.com"
)
.
SetUsername
(
"legacy-user"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
legacyIdentity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
legacyUser
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
wechatOAuthProviderKey
)
.
SetProviderSubject
(
"openid-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"openid"
:
"openid-123"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
recorder
:=
httptest
.
NewRecorder
()
c
,
_
:=
gin
.
CreateTestContext
(
recorder
)
req
:=
httptest
.
NewRequest
(
http
.
MethodGet
,
"/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123"
,
nil
)
req
.
Host
=
"api.example.com"
req
.
AddCookie
(
encodedCookie
(
wechatOAuthStateCookieName
,
"state-123"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthRedirectCookieName
,
"/dashboard"
))
req
.
AddCookie
(
encodedCookie
(
wechatOAuthModeCookieName
,
"open"
))
req
.
AddCookie
(
encodedCookie
(
oauthPendingBrowserCookieName
,
"browser-123"
))
c
.
Request
=
req
handler
.
WeChatOAuthCallback
(
c
)
require
.
Equal
(
t
,
http
.
StatusFound
,
recorder
.
Code
)
require
.
Equal
(
t
,
"/auth/wechat/callback"
,
recorder
.
Header
()
.
Get
(
"Location"
))
sessionCookie
:=
findCookie
(
recorder
.
Result
()
.
Cookies
(),
oauthPendingSessionCookieName
)
require
.
NotNil
(
t
,
sessionCookie
)
session
,
err
:=
client
.
PendingAuthSession
.
Query
()
.
Where
(
pendingauthsession
.
SessionTokenEQ
(
decodeCookieValueForTest
(
t
,
sessionCookie
.
Value
)))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
session
.
TargetUserID
)
require
.
Equal
(
t
,
legacyUser
.
ID
,
*
session
.
TargetUserID
)
require
.
Equal
(
t
,
legacyUser
.
Email
,
session
.
ResolvedEmail
)
repairedIdentity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
wechatOAuthProviderKey
),
authidentity
.
ProviderSubjectEQ
(
"union-456"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
legacyIdentity
.
ID
,
repairedIdentity
.
ID
)
require
.
Equal
(
t
,
legacyUser
.
ID
,
repairedIdentity
.
UserID
)
openIDIdentityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderKeyEQ
(
wechatOAuthProviderKey
),
authidentity
.
ProviderSubjectEQ
(
"openid-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Zero
(
t
,
openIDIdentityCount
)
channel
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ProviderKeyEQ
(
wechatOAuthProviderKey
),
authidentitychannel
.
ChannelEQ
(
"open"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open-app"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-123"
),
)
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
repairedIdentity
.
ID
,
channel
.
IdentityID
)
}
func
newWeChatOAuthTestHandler
(
t
*
testing
.
T
,
invitationEnabled
bool
)
(
*
AuthHandler
,
*
dbent
.
Client
)
{
func
newWeChatOAuthTestHandler
(
t
*
testing
.
T
,
invitationEnabled
bool
)
(
*
AuthHandler
,
*
dbent
.
Client
)
{
t
.
Helper
()
t
.
Helper
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment