Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
陈曦
sub2api
Commits
ddf80f5e
Unverified
Commit
ddf80f5e
authored
Apr 22, 2026
by
Wesley Liddick
Committed by
GitHub
Apr 22, 2026
Browse files
Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation
fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents
4d0483f5
c048ca80
Changes
140
Hide whitespace changes
Inline
Side-by-side
backend/internal/repository/migrations_runner_extra_test.go
View file @
ddf80f5e
...
...
@@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require
.
True
(
t
,
isMigrationChecksumCompatible
(
name
,
accepted
,
rule
.
fileChecksum
))
}
func
TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations
(
t
*
testing
.
T
)
{
for
_
,
name
:=
range
[]
string
{
"109_auth_identity_compat_backfill.sql"
,
"110_pending_auth_and_provider_default_grants.sql"
,
"112_add_payment_order_provider_key_snapshot.sql"
,
"115_auth_identity_legacy_external_backfill.sql"
,
"116_auth_identity_legacy_external_safety_reports.sql"
,
"118_wechat_dual_mode_and_auth_source_defaults.sql"
,
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql"
,
}
{
rule
,
ok
:=
migrationChecksumCompatibilityRules
[
name
]
require
.
Truef
(
t
,
ok
,
"missing compatibility rule for %s"
,
name
)
require
.
NotEmpty
(
t
,
rule
.
fileChecksum
)
require
.
NotEmpty
(
t
,
rule
.
acceptedDBChecksum
)
}
}
func
TestEnsureAtlasBaselineAligned
(
t
*
testing
.
T
)
{
t
.
Run
(
"skip_when_no_legacy_table"
,
func
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
...
...
backend/internal/repository/migrations_runner_notx_test.go
View file @
ddf80f5e
...
...
@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
db
.
Close
()
}()
prepareMigrationsBootstrapExpectations
(
mock
)
mock
.
ExpectQuery
(
"SELECT checksum FROM schema_migrations WHERE filename =
\\
$1"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
)
.
WillReturnError
(
sql
.
ErrNoRows
)
mock
.
ExpectQuery
(
"SELECT out_trade_no, COUNT
\\
(
\\
*
\\
) AS duplicate_count FROM payment_orders"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"out_trade_no"
,
"duplicate_count"
})
.
AddRow
(
"dup-out-trade-no"
,
2
))
mock
.
ExpectExec
(
"SELECT pg_advisory_unlock
\\
(
\\
$1
\\
)"
)
.
WithArgs
(
migrationsAdvisoryLockID
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
fsys
:=
fstest
.
MapFS
{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
&
fstest
.
MapFile
{
Data
:
[]
byte
(
`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`
),
},
}
err
=
applyMigrationsFS
(
context
.
Background
(),
db
,
fsys
)
require
.
Error
(
t
,
err
)
require
.
Contains
(
t
,
err
.
Error
(),
"duplicate out_trade_no"
)
require
.
Contains
(
t
,
err
.
Error
(),
"dup-out-trade-no"
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
defer
func
()
{
_
=
db
.
Close
()
}()
prepareMigrationsBootstrapExpectations
(
mock
)
mock
.
ExpectQuery
(
"SELECT checksum FROM schema_migrations WHERE filename =
\\
$1"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
)
.
WillReturnError
(
sql
.
ErrNoRows
)
mock
.
ExpectQuery
(
"SELECT out_trade_no, COUNT
\\
(
\\
*
\\
) AS duplicate_count FROM payment_orders"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"out_trade_no"
,
"duplicate_count"
}))
mock
.
ExpectQuery
(
"SELECT EXISTS
\\
("
)
.
WithArgs
(
"paymentorder_out_trade_no_unique"
)
.
WillReturnRows
(
sqlmock
.
NewRows
([]
string
{
"exists"
})
.
AddRow
(
true
))
mock
.
ExpectExec
(
"DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no"
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
0
))
mock
.
ExpectExec
(
"INSERT INTO schema_migrations
\\
(filename, checksum
\\
) VALUES
\\
(
\\
$1,
\\
$2
\\
)"
)
.
WithArgs
(
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
,
sqlmock
.
AnyArg
())
.
WillReturnResult
(
sqlmock
.
NewResult
(
1
,
1
))
mock
.
ExpectExec
(
"SELECT pg_advisory_unlock
\\
(
\\
$1
\\
)"
)
.
WithArgs
(
migrationsAdvisoryLockID
)
.
WillReturnResult
(
sqlmock
.
NewResult
(
0
,
1
))
fsys
:=
fstest
.
MapFS
{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql"
:
&
fstest
.
MapFile
{
Data
:
[]
byte
(
`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`
),
},
}
err
=
applyMigrationsFS
(
context
.
Background
(),
db
,
fsys
)
require
.
NoError
(
t
,
err
)
require
.
NoError
(
t
,
mock
.
ExpectationsWereMet
())
}
func
TestApplyMigrationsFS_TransactionalMigration
(
t
*
testing
.
T
)
{
db
,
mock
,
err
:=
sqlmock
.
New
()
require
.
NoError
(
t
,
err
)
...
...
backend/internal/repository/migrations_schema_integration_test.go
View file @
ddf80f5e
...
...
@@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn
(
t
,
tx
,
"user_allowed_groups"
,
"created_at"
,
"timestamp with time zone"
,
0
,
false
)
}
func
TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned
(
t
*
testing
.
T
)
{
tx
:=
testTx
(
t
)
requireColumn
(
t
,
tx
,
"auth_identity_migration_reports"
,
"report_type"
,
"character varying"
,
80
,
false
)
requireColumn
(
t
,
tx
,
"users"
,
"signup_source"
,
"character varying"
,
20
,
false
)
requireColumnDefaultContains
(
t
,
tx
,
"users"
,
"signup_source"
,
"email"
)
requireConstraintDefinitionContains
(
t
,
tx
,
"users"
,
"users_signup_source_check"
,
"signup_source"
,
"'email'"
,
"'linuxdo'"
,
"'wechat'"
,
"'oidc'"
,
)
requireForeignKeyOnDelete
(
t
,
tx
,
"auth_identities"
,
"user_id"
,
"users"
,
"CASCADE"
)
requireForeignKeyOnDelete
(
t
,
tx
,
"auth_identity_channels"
,
"identity_id"
,
"auth_identities"
,
"CASCADE"
)
requireForeignKeyOnDelete
(
t
,
tx
,
"pending_auth_sessions"
,
"target_user_id"
,
"users"
,
"SET NULL"
)
requireForeignKeyOnDelete
(
t
,
tx
,
"identity_adoption_decisions"
,
"pending_auth_session_id"
,
"pending_auth_sessions"
,
"CASCADE"
)
requireForeignKeyOnDelete
(
t
,
tx
,
"identity_adoption_decisions"
,
"identity_id"
,
"auth_identities"
,
"SET NULL"
)
requireIndex
(
t
,
tx
,
"payment_orders"
,
"paymentorder_out_trade_no"
)
requirePartialUniqueIndexDefinition
(
t
,
tx
,
"payment_orders"
,
"paymentorder_out_trade_no"
,
"out_trade_no"
,
"WHERE"
)
requireIndexAbsent
(
t
,
tx
,
"payment_orders"
,
"paymentorder_out_trade_no_unique"
)
}
func
requireIndex
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
index
string
)
{
t
.
Helper
()
...
...
@@ -106,6 +135,118 @@ SELECT EXISTS (
require
.
True
(
t
,
exists
,
"expected index %s on %s"
,
index
,
table
)
}
func
requireIndexAbsent
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
index
string
)
{
t
.
Helper
()
var
exists
bool
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`
,
table
,
index
)
.
Scan
(
&
exists
)
require
.
NoError
(
t
,
err
,
"query pg_indexes for %s.%s"
,
table
,
index
)
require
.
False
(
t
,
exists
,
"expected index %s on %s to be absent"
,
index
,
table
)
}
func
requirePartialUniqueIndexDefinition
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
index
string
,
fragments
...
string
)
{
t
.
Helper
()
var
(
unique
bool
def
string
)
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT
i.indisunique,
pg_get_indexdef(i.indexrelid)
FROM pg_class idx
JOIN pg_index i ON i.indexrelid = idx.oid
JOIN pg_class tbl ON tbl.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND idx.relname = $2
`
,
table
,
index
)
.
Scan
(
&
unique
,
&
def
)
require
.
NoError
(
t
,
err
,
"query index definition for %s.%s"
,
table
,
index
)
require
.
True
(
t
,
unique
,
"expected index %s on %s to be unique"
,
index
,
table
)
for
_
,
fragment
:=
range
fragments
{
require
.
Contains
(
t
,
def
,
fragment
,
"expected index definition for %s.%s to contain %q"
,
table
,
index
,
fragment
)
}
}
func
requireForeignKeyOnDelete
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
,
refTable
,
expected
string
)
{
t
.
Helper
()
var
actual
string
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT CASE c.confdeltype
WHEN 'a' THEN 'NO ACTION'
WHEN 'r' THEN 'RESTRICT'
WHEN 'c' THEN 'CASCADE'
WHEN 'n' THEN 'SET NULL'
WHEN 'd' THEN 'SET DEFAULT'
END
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
WHERE ns.nspname = 'public'
AND c.contype = 'f'
AND tbl.relname = $1
AND attr.attname = $2
AND ref_tbl.relname = $3
LIMIT 1
`
,
table
,
column
,
refTable
)
.
Scan
(
&
actual
)
require
.
NoError
(
t
,
err
,
"query foreign key action for %s.%s -> %s"
,
table
,
column
,
refTable
)
require
.
Equal
(
t
,
expected
,
actual
,
"unexpected ON DELETE action for %s.%s -> %s"
,
table
,
column
,
refTable
)
}
func
requireConstraintDefinitionContains
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
constraint
string
,
fragments
...
string
)
{
t
.
Helper
()
var
def
string
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`
,
table
,
constraint
)
.
Scan
(
&
def
)
require
.
NoError
(
t
,
err
,
"query constraint definition for %s.%s"
,
table
,
constraint
)
for
_
,
fragment
:=
range
fragments
{
require
.
Contains
(
t
,
def
,
fragment
,
"expected constraint definition for %s.%s to contain %q"
,
table
,
constraint
,
fragment
)
}
}
func
requireColumnDefaultContains
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
string
,
fragments
...
string
)
{
t
.
Helper
()
var
columnDefault
sql
.
NullString
err
:=
tx
.
QueryRowContext
(
context
.
Background
(),
`
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`
,
table
,
column
)
.
Scan
(
&
columnDefault
)
require
.
NoError
(
t
,
err
,
"query column_default for %s.%s"
,
table
,
column
)
require
.
True
(
t
,
columnDefault
.
Valid
,
"expected column_default for %s.%s"
,
table
,
column
)
for
_
,
fragment
:=
range
fragments
{
require
.
Contains
(
t
,
columnDefault
.
String
,
fragment
,
"expected default for %s.%s to contain %q"
,
table
,
column
,
fragment
)
}
}
func
requireColumn
(
t
*
testing
.
T
,
tx
*
sql
.
Tx
,
table
,
column
,
dataType
string
,
maxLen
int
,
nullable
bool
)
{
t
.
Helper
()
...
...
backend/internal/repository/user_profile_identity_repo.go
View file @
ddf80f5e
...
...
@@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"fmt"
"hash/fnv"
"reflect"
"sort"
"strings"
"sync"
"time"
"unsafe"
"entgo.io/ent/dialect"
entsql
"entgo.io/ent/dialect/sql"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
...
...
@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
QueryContext
(
ctx
context
.
Context
,
query
string
,
args
...
any
)
(
*
sql
.
Rows
,
error
)
}
var
repositoryScopedKeyLocks
=
newScopedKeyLockRegistry
()
type
scopedKeyLockRegistry
struct
{
mu
sync
.
Mutex
locks
map
[
string
]
*
scopedKeyLockEntry
}
type
scopedKeyLockEntry
struct
{
mu
sync
.
Mutex
refs
int
}
func
newScopedKeyLockRegistry
()
*
scopedKeyLockRegistry
{
return
&
scopedKeyLockRegistry
{
locks
:
make
(
map
[
string
]
*
scopedKeyLockEntry
),
}
}
func
(
r
*
scopedKeyLockRegistry
)
lock
(
keys
...
string
)
func
()
{
normalized
:=
normalizeLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
{
return
func
()
{}
}
entries
:=
make
([]
*
scopedKeyLockEntry
,
0
,
len
(
normalized
))
r
.
mu
.
Lock
()
for
_
,
key
:=
range
normalized
{
entry
:=
r
.
locks
[
key
]
if
entry
==
nil
{
entry
=
&
scopedKeyLockEntry
{}
r
.
locks
[
key
]
=
entry
}
entry
.
refs
++
entries
=
append
(
entries
,
entry
)
}
r
.
mu
.
Unlock
()
for
_
,
entry
:=
range
entries
{
entry
.
mu
.
Lock
()
}
return
func
()
{
for
i
:=
len
(
entries
)
-
1
;
i
>=
0
;
i
--
{
entries
[
i
]
.
mu
.
Unlock
()
}
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
for
idx
,
key
:=
range
normalized
{
entry
:=
entries
[
idx
]
entry
.
refs
--
if
entry
.
refs
==
0
{
delete
(
r
.
locks
,
key
)
}
}
}
}
func
normalizeLockKeys
(
keys
...
string
)
[]
string
{
if
len
(
keys
)
==
0
{
return
nil
}
deduped
:=
make
(
map
[
string
]
struct
{},
len
(
keys
))
for
_
,
key
:=
range
keys
{
trimmed
:=
strings
.
TrimSpace
(
key
)
if
trimmed
==
""
{
continue
}
deduped
[
trimmed
]
=
struct
{}{}
}
if
len
(
deduped
)
==
0
{
return
nil
}
normalized
:=
make
([]
string
,
0
,
len
(
deduped
))
for
key
:=
range
deduped
{
normalized
=
append
(
normalized
,
key
)
}
sort
.
Strings
(
normalized
)
return
normalized
}
func
advisoryLockHash
(
key
string
)
int64
{
hasher
:=
fnv
.
New64a
()
_
,
_
=
hasher
.
Write
([]
byte
(
key
))
return
int64
(
hasher
.
Sum64
())
}
func
lockRepositoryScopedKeys
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
exec
sqlQueryExecutor
,
keys
...
string
)
(
func
(),
error
)
{
release
:=
repositoryScopedKeyLocks
.
lock
(
keys
...
)
normalized
:=
normalizeLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
||
client
==
nil
||
exec
==
nil
||
client
.
Driver
()
.
Dialect
()
!=
dialect
.
Postgres
{
return
release
,
nil
}
for
_
,
key
:=
range
normalized
{
rows
,
err
:=
exec
.
QueryContext
(
ctx
,
"SELECT pg_advisory_xact_lock($1)"
,
advisoryLockHash
(
key
))
if
err
!=
nil
{
release
()
return
nil
,
err
}
_
=
rows
.
Close
()
}
return
release
,
nil
}
func
(
r
*
userRepository
)
WithUserProfileIdentityTx
(
ctx
context
.
Context
,
fn
func
(
txCtx
context
.
Context
)
error
)
error
{
if
dbent
.
TxFromContext
(
ctx
)
!=
nil
{
return
fn
(
ctx
)
...
...
@@ -301,17 +412,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
client
:=
clientFromContext
(
txCtx
,
r
.
client
)
canonical
:=
input
.
Canonical
identity
,
err
:=
client
.
AuthIdentity
.
Query
()
.
identity
Records
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
canonical
.
ProviderType
)),
authidentity
.
ProviderKey
EQ
(
strings
.
TrimSpace
(
canonical
.
ProviderKey
)),
authidentity
.
ProviderKey
In
(
compatibleIdentityProviderKeys
(
canonical
.
ProviderType
,
canonical
.
ProviderKey
)
...
),
authidentity
.
ProviderSubjectEQ
(
strings
.
TrimSpace
(
canonical
.
ProviderSubject
)),
)
.
Only
(
txCtx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
All
(
txCtx
)
if
err
!=
nil
{
return
err
}
if
identity
!=
nil
&&
identity
.
UserID
!=
input
.
UserID
{
identity
:=
selectOwnedCompatibleIdentity
(
identityRecords
,
input
.
UserID
)
if
identity
==
nil
&&
hasCompatibleIdentityConflict
(
identityRecords
,
input
.
UserID
)
{
return
ErrAuthIdentityOwnershipConflict
}
if
identity
==
nil
{
...
...
@@ -328,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return
err
}
}
else
{
targetProviderKey
:=
canonicalizeCompatibleIdentityProviderKey
(
canonical
.
ProviderType
,
identity
.
ProviderKey
,
canonical
.
ProviderKey
)
update
:=
client
.
AuthIdentity
.
UpdateOneID
(
identity
.
ID
)
if
targetProviderKey
!=
""
&&
!
strings
.
EqualFold
(
targetProviderKey
,
identity
.
ProviderKey
)
{
update
=
update
.
SetProviderKey
(
targetProviderKey
)
}
if
input
.
Metadata
!=
nil
{
update
=
update
.
SetMetadata
(
copyMetadata
(
input
.
Metadata
))
}
...
...
@@ -346,20 +462,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
var
channel
*
dbent
.
AuthIdentityChannel
if
input
.
Channel
!=
nil
{
channel
,
err
=
client
.
AuthIdentityChannel
.
Query
()
.
channel
Records
,
err
:
=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
strings
.
TrimSpace
(
input
.
Channel
.
ProviderType
)),
authidentitychannel
.
ProviderKey
EQ
(
strings
.
TrimSpace
(
input
.
Channel
.
ProviderKey
)),
authidentitychannel
.
ProviderKey
In
(
compatibleIdentityProviderKeys
(
input
.
Channel
.
ProviderType
,
input
.
Channel
.
ProviderKey
)
...
),
authidentitychannel
.
ChannelEQ
(
strings
.
TrimSpace
(
input
.
Channel
.
Channel
)),
authidentitychannel
.
ChannelAppIDEQ
(
strings
.
TrimSpace
(
input
.
Channel
.
ChannelAppID
)),
authidentitychannel
.
ChannelSubjectEQ
(
strings
.
TrimSpace
(
input
.
Channel
.
ChannelSubject
)),
)
.
WithIdentity
()
.
Only
(
txCtx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
All
(
txCtx
)
if
err
!=
nil
{
return
err
}
if
channel
!=
nil
&&
channel
.
Edges
.
Identity
!=
nil
&&
channel
.
Edges
.
Identity
.
UserID
!=
input
.
UserID
{
channel
=
selectOwnedCompatibleChannel
(
channelRecords
,
input
.
UserID
)
if
channel
==
nil
&&
hasCompatibleChannelConflict
(
channelRecords
,
input
.
UserID
)
{
return
ErrAuthIdentityChannelOwnershipConflict
}
if
channel
==
nil
{
...
...
@@ -376,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return
err
}
}
else
{
targetProviderKey
:=
canonicalizeCompatibleIdentityProviderKey
(
input
.
Channel
.
ProviderType
,
channel
.
ProviderKey
,
input
.
Channel
.
ProviderKey
)
update
:=
client
.
AuthIdentityChannel
.
UpdateOneID
(
channel
.
ID
)
.
SetIdentityID
(
identity
.
ID
)
if
targetProviderKey
!=
""
&&
!
strings
.
EqualFold
(
targetProviderKey
,
channel
.
ProviderKey
)
{
update
=
update
.
SetProviderKey
(
targetProviderKey
)
}
if
input
.
ChannelMetadata
!=
nil
{
update
=
update
.
SetMetadata
(
copyMetadata
(
input
.
ChannelMetadata
))
}
...
...
@@ -397,6 +518,104 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return
result
,
nil
}
func
compatibleIdentityProviderKeys
(
providerType
,
providerKey
string
)
[]
string
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerKey
==
""
{
return
[]
string
{
providerKey
}
}
if
providerType
!=
"wechat"
{
return
[]
string
{
providerKey
}
}
keys
:=
[]
string
{
providerKey
}
if
!
strings
.
EqualFold
(
providerKey
,
"wechat-main"
)
{
keys
=
append
(
keys
,
"wechat-main"
)
}
if
!
strings
.
EqualFold
(
providerKey
,
"wechat"
)
{
keys
=
append
(
keys
,
"wechat"
)
}
return
keys
}
func
canonicalizeCompatibleIdentityProviderKey
(
providerType
,
existingKey
,
requestedKey
string
)
string
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
existingKey
=
strings
.
TrimSpace
(
existingKey
)
requestedKey
=
strings
.
TrimSpace
(
requestedKey
)
if
providerType
!=
"wechat"
{
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
if
strings
.
EqualFold
(
existingKey
,
"wechat"
)
||
strings
.
EqualFold
(
existingKey
,
"wechat-main"
)
||
strings
.
EqualFold
(
requestedKey
,
"wechat-main"
)
{
return
"wechat-main"
}
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
func
compatibleIdentityProviderKeyRank
(
providerType
,
providerKey
string
)
int
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerType
!=
"wechat"
{
return
0
}
switch
{
case
strings
.
EqualFold
(
providerKey
,
"wechat-main"
)
:
return
0
case
strings
.
EqualFold
(
providerKey
,
"wechat"
)
:
return
2
default
:
return
1
}
}
func
selectOwnedCompatibleIdentity
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
*
dbent
.
AuthIdentity
{
var
selected
*
dbent
.
AuthIdentity
for
_
,
record
:=
range
records
{
if
record
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
compatibleIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
compatibleIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
selected
}
func
hasCompatibleIdentityConflict
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
bool
{
for
_
,
record
:=
range
records
{
if
record
.
UserID
!=
userID
{
return
true
}
}
return
false
}
func
selectOwnedCompatibleChannel
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
*
dbent
.
AuthIdentityChannel
{
var
selected
*
dbent
.
AuthIdentityChannel
for
_
,
record
:=
range
records
{
if
record
.
Edges
.
Identity
==
nil
||
record
.
Edges
.
Identity
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
compatibleIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
compatibleIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
selected
}
func
hasCompatibleChannelConflict
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
bool
{
for
_
,
record
:=
range
records
{
if
record
.
Edges
.
Identity
!=
nil
&&
record
.
Edges
.
Identity
.
UserID
!=
userID
{
return
true
}
}
return
false
}
func
(
r
*
userRepository
)
RecordProviderGrant
(
ctx
context
.
Context
,
input
ProviderGrantRecordInput
)
(
bool
,
error
)
{
exec
:=
txAwareSQLExecutor
(
ctx
,
r
.
sql
,
r
.
client
)
if
exec
==
nil
{
...
...
@@ -422,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
}
func
(
r
*
userRepository
)
UpsertIdentityAdoptionDecision
(
ctx
context
.
Context
,
input
IdentityAdoptionDecisionInput
)
(
*
dbent
.
IdentityAdoptionDecision
,
error
)
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
if
_
,
err
:=
client
.
IdentityAdoptionDecision
.
Update
()
.
Where
(
identityadoptiondecision
.
IdentityIDEQ
(
*
input
.
IdentityID
),
dbpredicate
.
IdentityAdoptionDecision
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
NEQ
(
col
,
input
.
PendingAuthSessionID
),
))
}),
)
.
ClearIdentityID
()
.
Save
(
ctx
);
err
!=
nil
{
return
nil
,
err
var
result
*
dbent
.
IdentityAdoptionDecision
err
:=
r
.
WithUserProfileIdentityTx
(
ctx
,
func
(
txCtx
context
.
Context
)
error
{
client
:=
clientFromContext
(
txCtx
,
r
.
client
)
releaseLocks
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
client
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
identityAdoptionDecisionLockKeys
(
input
.
PendingAuthSessionID
,
input
.
IdentityID
)
...
,
)
if
err
!=
nil
{
return
err
}
defer
releaseLocks
()
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
if
_
,
err
:=
client
.
IdentityAdoptionDecision
.
Update
()
.
Where
(
identityadoptiondecision
.
IdentityIDEQ
(
*
input
.
IdentityID
),
dbpredicate
.
IdentityAdoptionDecision
(
func
(
s
*
entsql
.
Selector
)
{
col
:=
s
.
C
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
s
.
Where
(
entsql
.
Or
(
entsql
.
IsNull
(
col
),
entsql
.
NEQ
(
col
,
input
.
PendingAuthSessionID
),
))
}),
)
.
ClearIdentityID
()
.
Save
(
txCtx
);
err
!=
nil
{
return
err
}
}
}
current
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
input
.
PendingAuthSessionID
))
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
now
:=
time
.
Now
()
.
UTC
()
if
current
==
nil
{
create
:=
client
.
IdentityAdoptionDecision
.
Create
()
.
SetPendingAuthSessionID
(
input
.
PendingAuthSessionID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
.
SetDecidedAt
(
now
)
if
input
.
IdentityID
!=
nil
{
SetDecidedAt
(
time
.
Now
()
.
UTC
()
)
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
create
=
create
.
SetIdentityID
(
*
input
.
IdentityID
)
}
return
create
.
Save
(
ctx
)
decisionID
,
err
:=
create
.
OnConflictColumns
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
.
UpdateNewValues
()
.
ID
(
txCtx
)
if
err
!=
nil
{
return
err
}
result
,
err
=
client
.
IdentityAdoptionDecision
.
Get
(
txCtx
,
decisionID
)
return
err
})
if
err
!=
nil
{
return
nil
,
err
}
return
result
,
nil
}
update
:=
client
.
IdentityAdoptionDecision
.
UpdateOneID
(
current
.
ID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
if
input
.
IdentityID
!=
nil
{
update
=
update
.
SetIdentityID
(
*
input
.
IdentityID
)
func
identityAdoptionDecisionLockKeys
(
pendingAuthSessionID
int64
,
identityID
*
int64
)
[]
string
{
keys
:=
[]
string
{
fmt
.
Sprintf
(
"identity-adoption:pending:%d"
,
pendingAuthSessionID
)}
if
identityID
!=
nil
&&
*
identityID
>
0
{
keys
=
append
(
keys
,
fmt
.
Sprintf
(
"identity-adoption:identity:%d"
,
*
identityID
))
}
return
update
.
Save
(
ctx
)
return
keys
}
func
(
r
*
userRepository
)
GetIdentityAdoptionDecisionByPendingAuthSessionID
(
ctx
context
.
Context
,
pendingAuthSessionID
int64
)
(
*
dbent
.
IdentityAdoptionDecision
,
error
)
{
...
...
backend/internal/repository/user_profile_identity_repo_contract_test.go
View file @
ddf80f5e
...
...
@@ -10,6 +10,8 @@ import (
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
...
...
@@ -186,6 +188,79 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
s
.
Require
()
.
ErrorIs
(
err
,
ErrAuthIdentityChannelOwnershipConflict
)
}
func
(
s
*
UserProfileIdentityRepoSuite
)
TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords
()
{
user
:=
s
.
mustCreateUser
(
"wechat-legacy-alias"
)
legacyIdentity
,
err
:=
s
.
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetProviderSubject
(
"union-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"source"
:
"legacy-alias"
})
.
Save
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
)
legacyChannel
,
err
:=
s
.
client
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
legacyIdentity
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetChannel
(
"oa"
)
.
SetChannelAppID
(
"wx-app-legacy"
)
.
SetChannelSubject
(
"openid-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"scene"
:
"legacy-alias"
})
.
Save
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
)
bound
,
err
:=
s
.
repo
.
BindAuthIdentityToUser
(
s
.
ctx
,
BindAuthIdentityInput
{
UserID
:
user
.
ID
,
Canonical
:
AuthIdentityKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
ProviderSubject
:
"union-legacy-123"
,
},
Channel
:
&
AuthIdentityChannelKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
Channel
:
"oa"
,
ChannelAppID
:
"wx-app-legacy"
,
ChannelSubject
:
"openid-legacy-123"
,
},
Metadata
:
map
[
string
]
any
{
"source"
:
"canonical-bind"
},
ChannelMetadata
:
map
[
string
]
any
{
"scene"
:
"canonical-bind"
},
})
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
NotNil
(
bound
)
s
.
Require
()
.
NotNil
(
bound
.
Identity
)
s
.
Require
()
.
NotNil
(
bound
.
Channel
)
s
.
Require
()
.
Equal
(
legacyIdentity
.
ID
,
bound
.
Identity
.
ID
)
s
.
Require
()
.
Equal
(
legacyChannel
.
ID
,
bound
.
Channel
.
ID
)
s
.
Require
()
.
Equal
(
"wechat-main"
,
bound
.
Identity
.
ProviderKey
)
s
.
Require
()
.
Equal
(
"wechat-main"
,
bound
.
Channel
.
ProviderKey
)
s
.
Require
()
.
Equal
(
"canonical-bind"
,
bound
.
Identity
.
Metadata
[
"source"
])
s
.
Require
()
.
Equal
(
"canonical-bind"
,
bound
.
Channel
.
Metadata
[
"scene"
])
identityCount
,
err
:=
s
.
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
UserIDEQ
(
user
.
ID
),
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderSubjectEQ
(
"union-legacy-123"
),
)
.
Count
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
1
,
identityCount
)
channelCount
,
err
:=
s
.
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ChannelEQ
(
"oa"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-app-legacy"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-legacy-123"
),
)
.
Count
(
s
.
ctx
)
s
.
Require
()
.
NoError
(
err
)
s
.
Require
()
.
Equal
(
1
,
channelCount
)
}
func
(
s
*
UserProfileIdentityRepoSuite
)
TestCreateAuthIdentity_RejectsChannelProviderMismatch
()
{
user
:=
s
.
mustCreateUser
(
"provider-mismatch-create"
)
...
...
backend/internal/repository/user_profile_identity_repo_unit_test.go
0 → 100644
View file @
ddf80f5e
package
repository
import
(
"context"
"sync"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func
TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
user
:=
&
service
.
User
{
Email
:
"wechat-legacy@example.com"
,
Username
:
"wechat-legacy"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
user
))
legacyIdentity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetProviderSubject
(
"union-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"source"
:
"legacy-alias"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
legacyChannel
,
err
:=
client
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
legacyIdentity
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetChannel
(
"oa"
)
.
SetChannelAppID
(
"wx-app-legacy"
)
.
SetChannelSubject
(
"openid-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"scene"
:
"legacy-alias"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
bound
,
err
:=
repo
.
BindAuthIdentityToUser
(
ctx
,
BindAuthIdentityInput
{
UserID
:
user
.
ID
,
Canonical
:
AuthIdentityKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
ProviderSubject
:
"union-legacy-123"
,
},
Channel
:
&
AuthIdentityChannelKey
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
Channel
:
"oa"
,
ChannelAppID
:
"wx-app-legacy"
,
ChannelSubject
:
"openid-legacy-123"
,
},
Metadata
:
map
[
string
]
any
{
"source"
:
"canonical-bind"
},
ChannelMetadata
:
map
[
string
]
any
{
"scene"
:
"canonical-bind"
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
bound
)
require
.
NotNil
(
t
,
bound
.
Identity
)
require
.
NotNil
(
t
,
bound
.
Channel
)
require
.
Equal
(
t
,
legacyIdentity
.
ID
,
bound
.
Identity
.
ID
)
require
.
Equal
(
t
,
legacyChannel
.
ID
,
bound
.
Channel
.
ID
)
require
.
Equal
(
t
,
"wechat-main"
,
bound
.
Identity
.
ProviderKey
)
require
.
Equal
(
t
,
"wechat-main"
,
bound
.
Channel
.
ProviderKey
)
reloadedIdentity
,
err
:=
client
.
AuthIdentity
.
Get
(
ctx
,
legacyIdentity
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
reloadedIdentity
.
ProviderKey
)
require
.
Equal
(
t
,
"canonical-bind"
,
reloadedIdentity
.
Metadata
[
"source"
])
reloadedChannel
,
err
:=
client
.
AuthIdentityChannel
.
Get
(
ctx
,
legacyChannel
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
reloadedChannel
.
ProviderKey
)
require
.
Equal
(
t
,
"canonical-bind"
,
reloadedChannel
.
Metadata
[
"scene"
])
identityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
UserIDEQ
(
user
.
ID
),
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderSubjectEQ
(
"union-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
identityCount
)
channelCount
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ChannelEQ
(
"oa"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-app-legacy"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
channelCount
)
}
func
TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
user
:=
&
service
.
User
{
Email
:
"repo-adoption@example.com"
,
Username
:
"repo-adoption"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
user
))
identity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"union-repo-adoption"
)
.
SetMetadata
(
map
[
string
]
any
{})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
session
,
err
:=
client
.
PendingAuthSession
.
Create
()
.
SetSessionToken
(
"pending-repo-adoption"
)
.
SetIntent
(
"bind_current_user"
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat-main"
)
.
SetProviderSubject
(
"union-repo-adoption"
)
.
SetExpiresAt
(
time
.
Now
()
.
UTC
()
.
Add
(
15
*
time
.
Minute
))
.
SetUpstreamIdentityClaims
(
map
[
string
]
any
{
"provider_subject"
:
"union-repo-adoption"
})
.
SetLocalFlowState
(
map
[
string
]
any
{
"step"
:
"pending"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
firstCreateStarted
:=
make
(
chan
struct
{})
releaseFirstCreate
:=
make
(
chan
struct
{})
var
firstCreate
sync
.
Once
client
.
IdentityAdoptionDecision
.
Use
(
func
(
next
dbent
.
Mutator
)
dbent
.
Mutator
{
return
dbent
.
MutateFunc
(
func
(
ctx
context
.
Context
,
m
dbent
.
Mutation
)
(
dbent
.
Value
,
error
)
{
blocked
:=
false
if
m
.
Op
()
.
Is
(
dbent
.
OpCreate
)
{
firstCreate
.
Do
(
func
()
{
blocked
=
true
close
(
firstCreateStarted
)
})
}
if
blocked
{
<-
releaseFirstCreate
}
return
next
.
Mutate
(
ctx
,
m
)
})
})
type
adoptionResult
struct
{
decision
*
dbent
.
IdentityAdoptionDecision
err
error
}
input
:=
IdentityAdoptionDecisionInput
{
PendingAuthSessionID
:
session
.
ID
,
IdentityID
:
&
identity
.
ID
,
AdoptDisplayName
:
true
,
AdoptAvatar
:
true
,
}
results
:=
make
(
chan
adoptionResult
,
2
)
go
func
()
{
decision
,
err
:=
repo
.
UpsertIdentityAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
<-
firstCreateStarted
go
func
()
{
decision
,
err
:=
repo
.
UpsertIdentityAdoptionDecision
(
ctx
,
input
)
results
<-
adoptionResult
{
decision
:
decision
,
err
:
err
}
}()
time
.
Sleep
(
100
*
time
.
Millisecond
)
close
(
releaseFirstCreate
)
first
:=
<-
results
second
:=
<-
results
require
.
NoError
(
t
,
first
.
err
)
require
.
NoError
(
t
,
second
.
err
)
require
.
NotNil
(
t
,
first
.
decision
)
require
.
NotNil
(
t
,
second
.
decision
)
require
.
Equal
(
t
,
first
.
decision
.
ID
,
second
.
decision
.
ID
)
count
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
loaded
,
err
:=
client
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
session
.
ID
))
.
Only
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
loaded
.
IdentityID
)
require
.
Equal
(
t
,
identity
.
ID
,
*
loaded
.
IdentityID
)
require
.
True
(
t
,
loaded
.
AdoptDisplayName
)
require
.
True
(
t
,
loaded
.
AdoptAvatar
)
}
backend/internal/repository/user_repo.go
View file @
ddf80f5e
...
...
@@ -52,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var
txClient
*
dbent
.
Client
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txClient
=
tx
.
Client
()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
else
{
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if
existingTx
:=
dbent
.
TxFromContext
(
ctx
);
existingTx
!=
nil
{
...
...
@@ -64,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
}
releaseEmailLock
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
txClient
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
normalizedEmailUniquenessLockKey
(
userIn
.
Email
),
)
if
err
!=
nil
{
return
err
}
defer
releaseEmailLock
()
if
err
:=
ensureNormalizedEmailAvailableWithClient
(
txCtx
,
txClient
,
0
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
created
,
err
:=
txClient
.
User
.
Create
()
.
SetEmail
(
userIn
.
Email
)
.
SetUsername
(
userIn
.
Username
)
.
...
...
@@ -76,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource
(
userSignupSourceOrDefault
(
userIn
.
SignupSource
))
.
SetNillableLastLoginAt
(
userIn
.
LastLoginAt
)
.
SetNillableLastActiveAt
(
userIn
.
LastActiveAt
)
.
Save
(
c
tx
)
Save
(
txC
tx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
nil
,
service
.
ErrEmailExists
)
}
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
c
tx
,
txClient
,
created
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
txC
tx
,
txClient
,
created
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
return
err
}
if
err
:=
ensureEmailAuthIdentityWithClient
(
c
tx
,
txClient
,
created
.
ID
,
created
.
Email
,
"user_repo_create"
);
err
!=
nil
{
if
err
:=
ensureEmailAuthIdentityWithClient
(
txC
tx
,
txClient
,
created
.
ID
,
created
.
Email
,
"user_repo_create"
);
err
!=
nil
{
return
err
}
...
...
@@ -154,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var
txClient
*
dbent
.
Client
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
txClient
=
tx
.
Client
()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
else
{
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if
existingTx
:=
dbent
.
TxFromContext
(
ctx
);
existingTx
!=
nil
{
...
...
@@ -165,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient
=
r
.
client
}
}
existing
,
err
:=
clientFromContext
(
ctx
,
txClient
)
.
User
.
Get
(
ctx
,
userIn
.
ID
)
releaseEmailLock
,
err
:=
lockRepositoryScopedKeys
(
txCtx
,
txClient
,
txAwareSQLExecutor
(
txCtx
,
r
.
sql
,
r
.
client
),
normalizedEmailUniquenessLockKey
(
userIn
.
Email
),
)
if
err
!=
nil
{
return
err
}
defer
releaseEmailLock
()
if
err
:=
ensureNormalizedEmailAvailableWithClient
(
txCtx
,
txClient
,
userIn
.
ID
,
userIn
.
Email
);
err
!=
nil
{
return
err
}
existing
,
err
:=
clientFromContext
(
txCtx
,
txClient
)
.
User
.
Get
(
txCtx
,
userIn
.
ID
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
nil
)
}
...
...
@@ -197,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if
userIn
.
BalanceNotifyThreshold
==
nil
{
updateOp
=
updateOp
.
ClearBalanceNotifyThreshold
()
}
updated
,
err
:=
updateOp
.
Save
(
c
tx
)
updated
,
err
:=
updateOp
.
Save
(
txC
tx
)
if
err
!=
nil
{
return
translatePersistenceError
(
err
,
service
.
ErrUserNotFound
,
service
.
ErrEmailExists
)
}
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
c
tx
,
txClient
,
updated
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
if
err
:=
r
.
syncUserAllowedGroupsWithClient
(
txC
tx
,
txClient
,
updated
.
ID
,
userIn
.
AllowedGroups
);
err
!=
nil
{
return
err
}
if
err
:=
replaceEmailAuthIdentityWithClient
(
c
tx
,
txClient
,
updated
.
ID
,
oldEmail
,
updated
.
Email
,
"user_repo_update"
);
err
!=
nil
{
if
err
:=
replaceEmailAuthIdentityWithClient
(
txC
tx
,
txClient
,
updated
.
ID
,
oldEmail
,
updated
.
Email
,
"user_repo_update"
);
err
!=
nil
{
return
err
}
...
...
@@ -704,8 +739,28 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return
r
.
client
.
User
.
Query
()
.
Where
(
userEmailLookupPredicate
(
email
))
.
Exist
(
ctx
)
}
func
ensureNormalizedEmailAvailableWithClient
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
userID
int64
,
email
string
)
error
{
client
=
clientFromContext
(
ctx
,
client
)
if
client
==
nil
{
return
nil
}
matches
,
err
:=
client
.
User
.
Query
()
.
Where
(
userEmailLookupPredicate
(
email
))
.
All
(
ctx
)
if
err
!=
nil
{
return
err
}
for
_
,
match
:=
range
matches
{
if
match
.
ID
!=
userID
{
return
service
.
ErrEmailExists
}
}
return
nil
}
func
userEmailLookupPredicate
(
email
string
)
predicate
.
User
{
normalized
:=
strings
.
ToLower
(
strings
.
TrimSpac
e
(
email
)
)
normalized
:=
normalizeEmailLookupValu
e
(
email
)
if
normalized
==
""
{
return
dbuser
.
EmailEQ
(
email
)
}
...
...
@@ -719,6 +774,18 @@ func userEmailLookupPredicate(email string) predicate.User {
})
}
func
normalizeEmailLookupValue
(
email
string
)
string
{
return
strings
.
ToLower
(
strings
.
TrimSpace
(
email
))
}
func
normalizedEmailUniquenessLockKey
(
email
string
)
string
{
normalized
:=
normalizeEmailLookupValue
(
email
)
if
normalized
==
""
{
return
""
}
return
"users:normalized-email:"
+
normalized
}
func
(
r
*
userRepository
)
AddGroupToAllowedGroups
(
ctx
context
.
Context
,
userID
int64
,
groupID
int64
)
error
{
client
:=
clientFromContext
(
ctx
,
r
.
client
)
err
:=
client
.
UserAllowedGroup
.
Create
()
.
...
...
@@ -853,11 +920,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
}
func
userSignupSourceOrDefault
(
signupSource
string
)
string
{
signupSource
=
strings
.
TrimSpace
(
signupSource
)
if
signupSource
==
""
{
switch
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
{
case
""
,
"email"
:
return
"email"
case
"linuxdo"
,
"wechat"
,
"oidc"
:
return
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
default
:
return
"email"
}
return
signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
...
...
backend/internal/repository/user_repo_email_lookup_unit_test.go
View file @
ddf80f5e
...
...
@@ -3,7 +3,10 @@ package repository
import
(
"context"
"database/sql"
"fmt"
"sync"
"testing"
"time"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
...
...
@@ -18,9 +21,10 @@ import (
func
newUserEntRepo
(
t
*
testing
.
T
)
(
*
userRepository
,
*
dbent
.
Client
)
{
t
.
Helper
()
db
,
err
:=
sql
.
Open
(
"sqlite"
,
"file:user_repo_email_lookup
?mode=memory&cache=shared
"
)
db
,
err
:=
sql
.
Open
(
"sqlite"
,
fmt
.
Sprintf
(
"file:%s
?mode=memory&cache=shared
&_fk=1"
,
t
.
Name
())
)
require
.
NoError
(
t
,
err
)
t
.
Cleanup
(
func
()
{
_
=
db
.
Close
()
})
db
.
SetMaxOpenConns
(
10
)
_
,
err
=
db
.
Exec
(
"PRAGMA foreign_keys = ON"
)
require
.
NoError
(
t
,
err
)
...
...
@@ -67,3 +71,157 @@ func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T)
require
.
NoError
(
t
,
err
)
require
.
True
(
t
,
exists
)
}
func
TestUserRepositoryCreateRejectsNormalizedEmailDuplicate
(
t
*
testing
.
T
)
{
repo
,
_
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
err
:=
repo
.
Create
(
ctx
,
&
service
.
User
{
Email
:
" Existing@Example.com "
,
Username
:
"existing-user"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
})
require
.
NoError
(
t
,
err
)
err
=
repo
.
Create
(
ctx
,
&
service
.
User
{
Email
:
"existing@example.com"
,
Username
:
"duplicate-user"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
})
require
.
ErrorIs
(
t
,
err
,
service
.
ErrEmailExists
)
}
func
TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate
(
t
*
testing
.
T
)
{
repo
,
_
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
first
:=
&
service
.
User
{
Email
:
" Existing@Example.com "
,
Username
:
"existing-user"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
first
))
second
:=
&
service
.
User
{
Email
:
"second@example.com"
,
Username
:
"second-user"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
}
require
.
NoError
(
t
,
repo
.
Create
(
ctx
,
second
))
second
.
Email
=
" existing@example.com "
err
:=
repo
.
Update
(
ctx
,
second
)
require
.
ErrorIs
(
t
,
err
,
service
.
ErrEmailExists
)
}
func
TestUserRepositoryGetByEmailReportsNormalizedEmailConflict
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
_
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"Conflict@Example.com"
)
.
SetUsername
(
"conflict-user-1"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
client
.
User
.
Create
()
.
SetEmail
(
" conflict@example.com "
)
.
SetUsername
(
"conflict-user-2"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
service
.
RoleUser
)
.
SetStatus
(
service
.
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
_
,
err
=
repo
.
GetByEmail
(
ctx
,
"conflict@example.com"
)
require
.
Error
(
t
,
err
)
require
.
ErrorContains
(
t
,
err
,
"normalized email lookup matched multiple users"
)
}
func
TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency
(
t
*
testing
.
T
)
{
repo
,
client
:=
newUserEntRepo
(
t
)
ctx
:=
context
.
Background
()
firstCreateStarted
:=
make
(
chan
struct
{})
releaseFirstCreate
:=
make
(
chan
struct
{})
var
firstCreate
sync
.
Once
client
.
User
.
Use
(
func
(
next
dbent
.
Mutator
)
dbent
.
Mutator
{
return
dbent
.
MutateFunc
(
func
(
ctx
context
.
Context
,
m
dbent
.
Mutation
)
(
dbent
.
Value
,
error
)
{
blocked
:=
false
if
m
.
Op
()
.
Is
(
dbent
.
OpCreate
)
{
firstCreate
.
Do
(
func
()
{
blocked
=
true
close
(
firstCreateStarted
)
})
}
if
blocked
{
<-
releaseFirstCreate
}
return
next
.
Mutate
(
ctx
,
m
)
})
})
type
createResult
struct
{
err
error
}
results
:=
make
(
chan
createResult
,
2
)
go
func
()
{
results
<-
createResult
{
err
:
repo
.
Create
(
ctx
,
&
service
.
User
{
Email
:
" Race@Example.com "
,
Username
:
"race-user-1"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
})}
}()
<-
firstCreateStarted
go
func
()
{
results
<-
createResult
{
err
:
repo
.
Create
(
ctx
,
&
service
.
User
{
Email
:
"race@example.com"
,
Username
:
"race-user-2"
,
PasswordHash
:
"hash"
,
Role
:
service
.
RoleUser
,
Status
:
service
.
StatusActive
,
})}
}()
time
.
Sleep
(
100
*
time
.
Millisecond
)
close
(
releaseFirstCreate
)
first
:=
<-
results
second
:=
<-
results
errors
:=
[]
error
{
first
.
err
,
second
.
err
}
successes
:=
0
conflicts
:=
0
for
_
,
err
:=
range
errors
{
switch
err
{
case
nil
:
successes
++
case
service
.
ErrEmailExists
:
conflicts
++
default
:
t
.
Fatalf
(
"unexpected create error: %v"
,
err
)
}
}
require
.
Equal
(
t
,
1
,
successes
)
require
.
Equal
(
t
,
1
,
conflicts
)
count
,
err
:=
client
.
User
.
Query
()
.
Where
(
userEmailLookupPredicate
(
"race@example.com"
))
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
count
)
}
backend/internal/server/api_contract_test.go
View file @
ddf80f5e
...
...
@@ -85,7 +85,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
...
...
@@ -93,7 +93,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
...
...
@@ -101,7 +101,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
...
...
@@ -122,7 +122,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
...
...
@@ -130,7 +130,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
...
...
@@ -138,7 +138,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
...
...
@@ -159,7 +159,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
...
...
@@ -167,7 +167,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
...
...
@@ -175,7 +175,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/
bind/
start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
...
...
@@ -784,6 +784,198 @@ func TestAPIContracts(t *testing.T) {
}
}`
,
},
{
name
:
"GET /api/v1/admin/settings falls back to config oauth defaults"
,
setup
:
func
(
t
*
testing
.
T
,
deps
*
contractDeps
)
{
t
.
Helper
()
deps
.
cfg
.
OIDC
=
config
.
OIDCConnectConfig
{
Enabled
:
true
,
ProviderName
:
"ConfigOIDC"
,
ClientID
:
"oidc-config-client"
,
ClientSecret
:
"oidc-config-secret"
,
IssuerURL
:
"https://issuer.example.com"
,
RedirectURL
:
"https://api.example.com/api/v1/auth/oauth/oidc/callback"
,
FrontendRedirectURL
:
"/auth/oidc/callback"
,
Scopes
:
"openid email profile"
,
TokenAuthMethod
:
"client_secret_post"
,
UsePKCE
:
true
,
ValidateIDToken
:
true
,
AllowedSigningAlgs
:
"RS256,ES256,PS256"
,
ClockSkewSeconds
:
120
,
}
deps
.
cfg
.
WeChat
=
config
.
WeChatConnectConfig
{
Enabled
:
true
,
OpenEnabled
:
true
,
OpenAppID
:
"wx-open-config"
,
OpenAppSecret
:
"wx-open-secret"
,
Mode
:
"open"
,
Scopes
:
"snsapi_login"
,
FrontendRedirectURL
:
"/auth/wechat/callback"
,
}
deps
.
settingRepo
.
SetAll
(
map
[
string
]
string
{
service
.
SettingKeyRegistrationEnabled
:
"true"
,
service
.
SettingKeyEmailVerifyEnabled
:
"false"
,
service
.
SettingKeyRegistrationEmailSuffixWhitelist
:
"[]"
,
})
},
method
:
http
.
MethodGet
,
path
:
"/api/v1/admin/settings"
,
wantStatus
:
http
.
StatusOK
,
wantJSON
:
`{
"code": 0,
"message": "success",
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"invitation_code_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "",
"smtp_port": 587,
"smtp_username": "",
"smtp_password_configured": false,
"smtp_from_email": "",
"smtp_from_name": "",
"smtp_use_tls": false,
"turnstile_enabled": false,
"turnstile_site_key": "",
"turnstile_secret_key_configured": false,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": true,
"oidc_connect_provider_name": "ConfigOIDC",
"oidc_connect_client_id": "oidc-config-client",
"oidc_connect_client_secret_configured": true,
"oidc_connect_issuer_url": "https://issuer.example.com",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
"api_base_url": "",
"contact_info": "",
"doc_url": "",
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50],
"custom_menu_items": [],
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_openai": "gpt-4o",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_antigravity": "gemini-2.5-pro",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60,
"min_claude_code_version": "",
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_balance_recharge_multiplier": 0,
"payment_recharge_fee_rate": 0,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
"wechat_connect_mode": "open",
"wechat_connect_open_enabled": true,
"wechat_connect_open_app_id": "wx-open-config",
"wechat_connect_open_app_secret_configured": true,
"wechat_connect_mp_enabled": false,
"wechat_connect_mp_app_id": "wx-open-config",
"wechat_connect_mp_app_secret_configured": true,
"wechat_connect_mobile_enabled": false,
"wechat_connect_mobile_app_id": "wx-open-config",
"wechat_connect_mobile_app_secret_configured": true,
"wechat_connect_redirect_url": "",
"wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
"wechat_connect_scopes": "snsapi_login",
"auth_source_default_email_balance": 0,
"auth_source_default_email_concurrency": 5,
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
"auth_source_default_linuxdo_grant_on_signup": false,
"auth_source_default_linuxdo_grant_on_first_bind": false,
"auth_source_default_oidc_balance": 0,
"auth_source_default_oidc_concurrency": 5,
"auth_source_default_oidc_subscriptions": [],
"auth_source_default_oidc_grant_on_signup": false,
"auth_source_default_oidc_grant_on_first_bind": false,
"auth_source_default_wechat_balance": 0,
"auth_source_default_wechat_concurrency": 5,
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"force_email_on_third_party_signup": false
}
}`
,
},
{
name
:
"POST /api/v1/admin/accounts/bulk-update"
,
method
:
http
.
MethodPost
,
...
...
@@ -827,6 +1019,7 @@ func TestAPIContracts(t *testing.T) {
type
contractDeps
struct
{
now
time
.
Time
router
http
.
Handler
cfg
*
config
.
Config
apiKeyRepo
*
stubApiKeyRepo
groupRepo
*
stubGroupRepo
userSubRepo
*
stubUserSubscriptionRepo
...
...
@@ -947,6 +1140,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return
&
contractDeps
{
now
:
now
,
router
:
r
,
cfg
:
cfg
,
apiKeyRepo
:
apiKeyRepo
,
groupRepo
:
groupRepo
,
userSubRepo
:
userSubRepo
,
...
...
backend/internal/server/middleware/backend_mode_guard.go
View file @
ddf80f5e
...
...
@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
func
backendModeAllowsAuthPath
(
path
string
)
bool
{
path
=
strings
.
ToLower
(
strings
.
TrimSpace
(
path
))
for
_
,
suffix
:=
range
[]
string
{
"/auth/login"
,
"/auth/login/2fa"
,
"/auth/logout"
,
"/auth/refresh"
}
{
if
strings
.
HasSuffix
(
path
,
suffix
)
{
return
true
}
}
for
_
,
suffix
:=
range
[]
string
{
"/auth/oauth/linuxdo/callback"
,
"/auth/oauth/wechat/callback"
,
"/auth/oauth/wechat/payment/callback"
,
"/auth/oauth/oidc/callback"
,
"/auth/oauth/linuxdo/complete-registration"
,
"/auth/oauth/wechat/complete-registration"
,
"/auth/oauth/oidc/complete-registration"
,
"/auth/oauth/linuxdo/create-account"
,
"/auth/oauth/wechat/create-account"
,
"/auth/oauth/oidc/create-account"
,
"/auth/oauth/linuxdo/bind-login"
,
"/auth/oauth/wechat/bind-login"
,
"/auth/oauth/oidc/bind-login"
,
}
{
if
strings
.
HasSuffix
(
path
,
suffix
)
{
return
true
}
}
return
strings
.
Contains
(
path
,
"/auth/oauth/pending/"
)
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
// Allows the minimal auth surface admins still need in backend mode, including
// OAuth callbacks and pending continuations. Handler-level backend mode checks
// still enforce admin-only login and forbid self-service registration.
func
BackendModeAuthGuard
(
settingService
*
service
.
SettingService
)
gin
.
HandlerFunc
{
return
func
(
c
*
gin
.
Context
)
{
if
settingService
==
nil
||
!
settingService
.
IsBackendModeEnabled
(
c
.
Request
.
Context
())
{
c
.
Next
()
return
}
path
:=
c
.
Request
.
URL
.
Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes
:=
[]
string
{
"/auth/login"
,
"/auth/login/2fa"
,
"/auth/logout"
,
"/auth/refresh"
}
for
_
,
suffix
:=
range
allowedSuffixes
{
if
strings
.
HasSuffix
(
path
,
suffix
)
{
c
.
Next
()
return
}
if
backendModeAllowsAuthPath
(
c
.
Request
.
URL
.
Path
)
{
c
.
Next
()
return
}
response
.
Forbidden
(
c
,
"Backend mode is active. Registration and self-service auth flows are disabled."
)
c
.
Abort
()
...
...
backend/internal/server/middleware/backend_mode_guard_test.go
View file @
ddf80f5e
...
...
@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path
:
"/api/v1/auth/refresh"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_linuxdo_oauth_start"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/linuxdo/start"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_allows_linuxdo_oauth_callback"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/linuxdo/callback"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_wechat_oauth_start"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/wechat/start"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_allows_wechat_oauth_callback"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/wechat/callback"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_wechat_payment_oauth_start"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/wechat/payment/start"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_allows_wechat_payment_oauth_callback"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/wechat/payment/callback"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_oidc_oauth_start"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/oidc/start"
,
wantStatus
:
http
.
StatusForbidden
,
},
{
name
:
"enabled_allows_oidc_oauth_callback"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/oidc/callback"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_oauth_pending_exchange"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/pending/exchange"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_oauth_pending_send_verify_code"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/pending/send-verify-code"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_oauth_pending_create_account"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/pending/create-account"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_oauth_pending_bind_login"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/pending/bind-login"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_provider_bind_login"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/oidc/bind-login"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_provider_create_account"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/wechat/create-account"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_allows_legacy_complete_registration"
,
enabled
:
"true"
,
path
:
"/api/v1/auth/oauth/linuxdo/complete-registration"
,
wantStatus
:
http
.
StatusOK
,
},
{
name
:
"enabled_blocks_register"
,
enabled
:
"true"
,
...
...
backend/internal/server/routes/auth.go
View file @
ddf80f5e
...
...
@@ -63,8 +63,20 @@ func RegisterAuthRoutes(
FailureMode
:
middleware
.
RateLimitFailClose
,
}),
h
.
Auth
.
ResetPassword
)
auth
.
GET
(
"/oauth/linuxdo/start"
,
h
.
Auth
.
LinuxDoOAuthStart
)
auth
.
GET
(
"/oauth/linuxdo/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
LinuxDoOAuthStart
(
c
)
})
auth
.
GET
(
"/oauth/linuxdo/callback"
,
h
.
Auth
.
LinuxDoOAuthCallback
)
auth
.
GET
(
"/oauth/wechat/start"
,
h
.
Auth
.
WeChatOAuthStart
)
auth
.
GET
(
"/oauth/wechat/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
WeChatOAuthStart
(
c
)
})
auth
.
GET
(
"/oauth/wechat/callback"
,
h
.
Auth
.
WeChatOAuthCallback
)
auth
.
GET
(
"/oauth/wechat/payment/start"
,
h
.
Auth
.
WeChatPaymentOAuthStart
)
auth
.
GET
(
"/oauth/wechat/payment/callback"
,
h
.
Auth
.
WeChatPaymentOAuthCallback
)
...
...
@@ -129,6 +141,12 @@ func RegisterAuthRoutes(
h
.
Auth
.
CreateWeChatOAuthAccount
,
)
auth
.
GET
(
"/oauth/oidc/start"
,
h
.
Auth
.
OIDCOAuthStart
)
auth
.
GET
(
"/oauth/oidc/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
OIDCOAuthStart
(
c
)
})
auth
.
GET
(
"/oauth/oidc/callback"
,
h
.
Auth
.
OIDCOAuthCallback
)
auth
.
POST
(
"/oauth/oidc/complete-registration"
,
rateLimiter
.
LimitWithOptions
(
"oauth-oidc-complete"
,
10
,
time
.
Minute
,
middleware
.
RateLimitOptions
{
...
...
@@ -164,23 +182,6 @@ func RegisterAuthRoutes(
authenticated
.
GET
(
"/auth/me"
,
h
.
Auth
.
GetCurrentUser
)
// 撤销所有会话(需要认证)
authenticated
.
POST
(
"/auth/revoke-all-sessions"
,
h
.
Auth
.
RevokeAllSessions
)
authenticated
.
GET
(
"/auth/oauth/linuxdo/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
LinuxDoOAuthStart
(
c
)
})
authenticated
.
GET
(
"/auth/oauth/oidc/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
OIDCOAuthStart
(
c
)
})
authenticated
.
GET
(
"/auth/oauth/wechat/bind/start"
,
func
(
c
*
gin
.
Context
)
{
query
:=
c
.
Request
.
URL
.
Query
()
query
.
Set
(
"intent"
,
"bind_current_user"
)
c
.
Request
.
URL
.
RawQuery
=
query
.
Encode
()
h
.
Auth
.
WeChatOAuthStart
(
c
)
})
authenticated
.
POST
(
"/auth/oauth/bind-token"
,
h
.
Auth
.
PrepareOAuthBindAccessTokenCookie
)
}
}
backend/internal/server/routes/payment.go
View file @
ddf80f5e
...
...
@@ -44,9 +44,9 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
// Signed resume-token recovery is the
support
ed public lookup path.
// The legacy anonymous out_trade_no verify endpoint
is kept only
as a
//
compatibility shim that returns HTTP 410 Gone
.
// Signed resume-token recovery is the
preferr
ed public lookup path.
// The legacy anonymous out_trade_no verify endpoint
remains available
as a
//
persisted-state compatibility path for staggered upgrades
.
public
:=
v1
.
Group
(
"/payment/public"
)
{
public
.
POST
(
"/orders/verify"
,
paymentHandler
.
VerifyOrderPublic
)
...
...
backend/internal/service/account_test_service.go
View file @
ddf80f5e
...
...
@@ -419,6 +419,7 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
// testOpenAIAccountConnection tests an OpenAI account's connection
func
(
s
*
AccountTestService
)
testOpenAIAccountConnection
(
c
*
gin
.
Context
,
account
*
Account
,
modelID
string
,
prompt
string
)
error
{
ctx
:=
c
.
Request
.
Context
()
_
=
prompt
// Default to openai.DefaultTestModel for OpenAI testing
testModelID
:=
modelID
...
...
backend/internal/service/admin_service.go
View file @
ddf80f5e
...
...
@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
if
providerKey
==
""
||
providerSubject
==
""
{
return
nil
,
infraerrors
.
BadRequest
(
"INVALID_INPUT"
,
"provider_type, provider_key, and provider_subject are required"
)
}
canonicalProviderKey
:=
canonicalAdminAuthIdentityProviderKey
(
providerType
,
""
,
providerKey
)
compatibleProviderKeys
:=
compatibleAdminAuthIdentityProviderKeys
(
providerType
,
providerKey
)
var
issuer
*
string
if
input
.
Issuer
!=
nil
{
...
...
@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
}
defer
func
()
{
_
=
tx
.
Rollback
()
}()
identity
,
err
:=
tx
.
AuthIdentity
.
Query
()
.
identity
Records
,
err
:=
tx
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
providerType
),
authidentity
.
ProviderKey
EQ
(
p
roviderKey
),
authidentity
.
ProviderKey
In
(
compatibleP
roviderKey
s
...
),
authidentity
.
ProviderSubjectEQ
(
providerSubject
),
)
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
InternalServer
(
"ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED"
,
"failed to inspect auth identity ownership"
)
.
WithCause
(
err
)
}
if
identity
!=
nil
&&
identity
.
UserID
!=
userID
{
if
hasAdminAuthIdentityOwnershipConflict
(
identityRecords
,
userID
)
{
return
nil
,
infraerrors
.
Conflict
(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT"
,
"auth identity already belongs to another user"
)
}
identity
:=
selectOwnedAdminAuthIdentity
(
identityRecords
,
userID
)
if
identity
==
nil
{
create
:=
tx
.
AuthIdentity
.
Create
()
.
SetUserID
(
userID
)
.
SetProviderType
(
providerType
)
.
SetProviderKey
(
p
roviderKey
)
.
SetProviderKey
(
canonicalP
roviderKey
)
.
SetProviderSubject
(
providerSubject
)
.
SetVerifiedAt
(
verifiedAt
)
if
issuer
!=
nil
{
...
...
@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return
nil
,
infraerrors
.
InternalServer
(
"ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED"
,
"failed to save auth identity"
)
.
WithCause
(
err
)
}
}
else
{
update
:=
tx
.
AuthIdentity
.
UpdateOneID
(
identity
.
ID
)
.
SetVerifiedAt
(
verifiedAt
)
update
:=
tx
.
AuthIdentity
.
UpdateOneID
(
identity
.
ID
)
.
SetVerifiedAt
(
verifiedAt
)
.
SetProviderKey
(
canonicalProviderKey
)
if
issuer
!=
nil
{
update
=
update
.
SetIssuer
(
*
issuer
)
}
...
...
@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
var
channel
*
dbent
.
AuthIdentityChannel
if
channelInput
!=
nil
{
channel
,
err
=
tx
.
AuthIdentityChannel
.
Query
()
.
channel
Records
,
err
:
=
tx
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
providerType
),
authidentitychannel
.
ProviderKey
EQ
(
p
roviderKey
),
authidentitychannel
.
ProviderKey
In
(
compatibleP
roviderKey
s
...
),
authidentitychannel
.
ChannelEQ
(
channelInput
.
Channel
),
authidentitychannel
.
ChannelAppIDEQ
(
channelInput
.
ChannelAppID
),
authidentitychannel
.
ChannelSubjectEQ
(
channelInput
.
ChannelSubject
),
)
.
WithIdentity
()
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
All
(
ctx
)
if
err
!=
nil
{
return
nil
,
infraerrors
.
InternalServer
(
"ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED"
,
"failed to inspect auth identity channel ownership"
)
.
WithCause
(
err
)
}
if
c
ha
nnel
!=
nil
&&
channel
.
Edges
.
Identity
!=
nil
&&
channel
.
Edges
.
Identity
.
UserID
!=
userID
{
if
ha
sAdminAuthIdentityChannelOwnershipConflict
(
channelRecords
,
userID
)
{
return
nil
,
infraerrors
.
Conflict
(
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT"
,
"auth identity channel already belongs to another user"
)
}
channel
=
selectOwnedAdminAuthIdentityChannel
(
channelRecords
,
userID
)
if
channel
==
nil
{
create
:=
tx
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
identity
.
ID
)
.
SetProviderType
(
providerType
)
.
SetProviderKey
(
p
roviderKey
)
.
SetProviderKey
(
canonicalP
roviderKey
)
.
SetChannel
(
channelInput
.
Channel
)
.
SetChannelAppID
(
channelInput
.
ChannelAppID
)
.
SetChannelSubject
(
channelInput
.
ChannelSubject
)
...
...
@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return
nil
,
infraerrors
.
InternalServer
(
"ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED"
,
"failed to save auth identity channel"
)
.
WithCause
(
err
)
}
}
else
{
update
:=
tx
.
AuthIdentityChannel
.
UpdateOneID
(
channel
.
ID
)
.
SetIdentityID
(
identity
.
ID
)
update
:=
tx
.
AuthIdentityChannel
.
UpdateOneID
(
channel
.
ID
)
.
SetIdentityID
(
identity
.
ID
)
.
SetProviderKey
(
canonicalProviderKey
)
if
channelInput
.
Metadata
!=
nil
{
update
=
update
.
SetMetadata
(
cloneAdminAuthIdentityMetadata
(
channelInput
.
Metadata
))
}
...
...
@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return
buildAdminBoundAuthIdentity
(
identity
,
channel
),
nil
}
func
compatibleAdminAuthIdentityProviderKeys
(
providerType
,
providerKey
string
)
[]
string
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerKey
==
""
{
return
[]
string
{
providerKey
}
}
if
providerType
!=
"wechat"
{
return
[]
string
{
providerKey
}
}
keys
:=
[]
string
{
providerKey
}
if
!
strings
.
EqualFold
(
providerKey
,
"wechat-main"
)
{
keys
=
append
(
keys
,
"wechat-main"
)
}
if
!
strings
.
EqualFold
(
providerKey
,
"wechat"
)
{
keys
=
append
(
keys
,
"wechat"
)
}
return
keys
}
func
canonicalAdminAuthIdentityProviderKey
(
providerType
,
existingKey
,
requestedKey
string
)
string
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
existingKey
=
strings
.
TrimSpace
(
existingKey
)
requestedKey
=
strings
.
TrimSpace
(
requestedKey
)
if
providerType
!=
"wechat"
{
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
if
strings
.
EqualFold
(
existingKey
,
"wechat"
)
||
strings
.
EqualFold
(
existingKey
,
"wechat-main"
)
||
strings
.
EqualFold
(
requestedKey
,
"wechat-main"
)
{
return
"wechat-main"
}
if
requestedKey
!=
""
{
return
requestedKey
}
return
existingKey
}
func
adminAuthIdentityProviderKeyRank
(
providerType
,
providerKey
string
)
int
{
providerType
=
strings
.
TrimSpace
(
strings
.
ToLower
(
providerType
))
providerKey
=
strings
.
TrimSpace
(
providerKey
)
if
providerType
!=
"wechat"
{
return
0
}
switch
{
case
strings
.
EqualFold
(
providerKey
,
"wechat-main"
)
:
return
0
case
strings
.
EqualFold
(
providerKey
,
"wechat"
)
:
return
2
default
:
return
1
}
}
func
selectOwnedAdminAuthIdentity
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
*
dbent
.
AuthIdentity
{
var
selected
*
dbent
.
AuthIdentity
for
_
,
record
:=
range
records
{
if
record
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
adminAuthIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
adminAuthIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
selected
}
func
hasAdminAuthIdentityOwnershipConflict
(
records
[]
*
dbent
.
AuthIdentity
,
userID
int64
)
bool
{
for
_
,
record
:=
range
records
{
if
record
.
UserID
!=
userID
{
return
true
}
}
return
false
}
func
selectOwnedAdminAuthIdentityChannel
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
*
dbent
.
AuthIdentityChannel
{
var
selected
*
dbent
.
AuthIdentityChannel
for
_
,
record
:=
range
records
{
if
record
.
Edges
.
Identity
==
nil
||
record
.
Edges
.
Identity
.
UserID
!=
userID
{
continue
}
if
selected
==
nil
||
adminAuthIdentityProviderKeyRank
(
record
.
ProviderType
,
record
.
ProviderKey
)
<
adminAuthIdentityProviderKeyRank
(
selected
.
ProviderType
,
selected
.
ProviderKey
)
{
selected
=
record
}
}
return
selected
}
func
hasAdminAuthIdentityChannelOwnershipConflict
(
records
[]
*
dbent
.
AuthIdentityChannel
,
userID
int64
)
bool
{
for
_
,
record
:=
range
records
{
if
record
.
Edges
.
Identity
!=
nil
&&
record
.
Edges
.
Identity
.
UserID
!=
userID
{
return
true
}
}
return
false
}
func
normalizeAdminBindChannelInput
(
input
*
AdminBindAuthIdentityChannelInput
)
*
AdminBindAuthIdentityChannelInput
{
if
input
==
nil
{
return
nil
...
...
backend/internal/service/admin_service_auth_identity_binding_test.go
View file @
ddf80f5e
...
...
@@ -188,6 +188,93 @@ func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
require
.
Equal
(
t
,
"second"
,
identities
[
0
]
.
Metadata
[
"source"
])
}
func
TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords
(
t
*
testing
.
T
)
{
client
:=
newAdminServiceAuthIdentityBindingTestClient
(
t
)
ctx
:=
context
.
Background
()
user
,
err
:=
client
.
User
.
Create
()
.
SetEmail
(
"wechat-alias@example.com"
)
.
SetPasswordHash
(
"hash"
)
.
SetRole
(
RoleUser
)
.
SetStatus
(
StatusActive
)
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
legacyIdentity
,
err
:=
client
.
AuthIdentity
.
Create
()
.
SetUserID
(
user
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetProviderSubject
(
"union-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"source"
:
"legacy"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
legacyChannel
,
err
:=
client
.
AuthIdentityChannel
.
Create
()
.
SetIdentityID
(
legacyIdentity
.
ID
)
.
SetProviderType
(
"wechat"
)
.
SetProviderKey
(
"wechat"
)
.
SetChannel
(
"open"
)
.
SetChannelAppID
(
"wx-open"
)
.
SetChannelSubject
(
"openid-legacy-123"
)
.
SetMetadata
(
map
[
string
]
any
{
"scene"
:
"legacy"
})
.
Save
(
ctx
)
require
.
NoError
(
t
,
err
)
svc
:=
&
adminServiceImpl
{
userRepo
:
&
userRepoStub
{
user
:
&
User
{
ID
:
user
.
ID
,
Email
:
user
.
Email
,
Status
:
StatusActive
}},
entClient
:
client
,
}
result
,
err
:=
svc
.
BindUserAuthIdentity
(
ctx
,
user
.
ID
,
AdminBindAuthIdentityInput
{
ProviderType
:
"wechat"
,
ProviderKey
:
"wechat-main"
,
ProviderSubject
:
"union-legacy-123"
,
Metadata
:
map
[
string
]
any
{
"source"
:
"admin-repair"
},
Channel
:
&
AdminBindAuthIdentityChannelInput
{
Channel
:
"open"
,
ChannelAppID
:
"wx-open"
,
ChannelSubject
:
"openid-legacy-123"
,
Metadata
:
map
[
string
]
any
{
"scene"
:
"admin-repair"
},
},
})
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
result
)
require
.
Equal
(
t
,
"wechat-main"
,
result
.
ProviderKey
)
require
.
NotNil
(
t
,
result
.
Channel
)
require
.
Equal
(
t
,
"open"
,
result
.
Channel
.
Channel
)
identity
,
err
:=
client
.
AuthIdentity
.
Get
(
ctx
,
legacyIdentity
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
identity
.
ProviderKey
)
require
.
Equal
(
t
,
"admin-repair"
,
identity
.
Metadata
[
"source"
])
channel
,
err
:=
client
.
AuthIdentityChannel
.
Get
(
ctx
,
legacyChannel
.
ID
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
"wechat-main"
,
channel
.
ProviderKey
)
require
.
Equal
(
t
,
legacyIdentity
.
ID
,
channel
.
IdentityID
)
require
.
Equal
(
t
,
"admin-repair"
,
channel
.
Metadata
[
"scene"
])
identityCount
,
err
:=
client
.
AuthIdentity
.
Query
()
.
Where
(
authidentity
.
ProviderTypeEQ
(
"wechat"
),
authidentity
.
ProviderSubjectEQ
(
"union-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
identityCount
)
channelCount
,
err
:=
client
.
AuthIdentityChannel
.
Query
()
.
Where
(
authidentitychannel
.
ProviderTypeEQ
(
"wechat"
),
authidentitychannel
.
ChannelEQ
(
"open"
),
authidentitychannel
.
ChannelAppIDEQ
(
"wx-open"
),
authidentitychannel
.
ChannelSubjectEQ
(
"openid-legacy-123"
),
)
.
Count
(
ctx
)
require
.
NoError
(
t
,
err
)
require
.
Equal
(
t
,
1
,
channelCount
)
}
func
TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType
(
t
*
testing
.
T
)
{
client
:=
newAdminServiceAuthIdentityBindingTestClient
(
t
)
ctx
:=
context
.
Background
()
...
...
backend/internal/service/auth_email_binding.go
View file @
ddf80f5e
...
...
@@ -11,6 +11,7 @@ import (
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors
"github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// BindEmailIdentity verifies and binds a local email/password identity to the
...
...
@@ -69,6 +70,7 @@ func (s *AuthService) BindEmailIdentity(
if
err
:=
s
.
updateBoundEmailIdentityTx
(
ctx
,
currentUser
,
normalizedEmail
,
hashedPassword
,
firstRealEmailBind
);
err
!=
nil
{
return
nil
,
err
}
s
.
revokeEmailIdentitySessions
(
ctx
,
userID
)
return
currentUser
,
nil
}
...
...
@@ -87,6 +89,7 @@ func (s *AuthService) BindEmailIdentity(
}
}
s
.
revokeEmailIdentitySessions
(
ctx
,
userID
)
return
currentUser
,
nil
}
...
...
@@ -219,6 +222,12 @@ func (s *AuthService) updateBoundEmailIdentityWithClient(
return
nil
}
func
(
s
*
AuthService
)
revokeEmailIdentitySessions
(
ctx
context
.
Context
,
userID
int64
)
{
if
err
:=
s
.
RevokeAllUserSessions
(
ctx
,
userID
);
err
!=
nil
{
logger
.
LegacyPrintf
(
"service.auth"
,
"[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v"
,
userID
,
err
)
}
}
func
replaceBoundEmailAuthIdentityWithClient
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
...
...
backend/internal/service/auth_oauth_email_flow.go
View file @
ddf80f5e
...
...
@@ -14,10 +14,14 @@ import (
func
normalizeOAuthSignupSource
(
signupSource
string
)
string
{
signupSource
=
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
if
signupSource
==
""
{
switch
signupSource
{
case
""
,
"email"
:
return
"email"
case
"linuxdo"
,
"wechat"
,
"oidc"
:
return
signupSource
default
:
return
"email"
}
return
signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
...
...
@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return
nil
,
nil
,
fmt
.
Errorf
(
"hash password: %w"
,
err
)
}
signupSource
=
strings
.
TrimSpace
(
strings
.
ToLower
(
signupSource
))
if
signupSource
==
""
{
signupSource
=
"email"
}
signupSource
=
normalizeOAuthSignupSource
(
signupSource
)
grantPlan
:=
s
.
resolveSignupGrantPlan
(
ctx
,
signupSource
)
user
:=
&
User
{
...
...
@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
Balance
:
grantPlan
.
Balance
,
Concurrency
:
grantPlan
.
Concurrency
,
Status
:
StatusActive
,
SignupSource
:
signupSource
,
}
if
err
:=
s
.
userRepo
.
Create
(
ctx
,
user
);
err
!=
nil
{
...
...
backend/internal/service/auth_oauth_email_flow_test.go
View file @
ddf80f5e
...
...
@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
require
.
Empty
(
t
,
redeemRepo
.
updateCalls
)
}
func
TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser
(
t
*
testing
.
T
)
{
userRepo
:=
&
userRepoStub
{
nextID
:
42
}
emailCache
:=
&
emailCacheStub
{
data
:
&
VerificationCodeData
{
Code
:
"246810"
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
ExpiresAt
:
time
.
Now
()
.
UTC
()
.
Add
(
15
*
time
.
Minute
),
},
}
authService
:=
newOAuthEmailFlowAuthService
(
userRepo
,
&
redeemCodeRepoStub
{},
&
refreshTokenCacheStub
{},
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
emailCache
,
)
tokenPair
,
user
,
err
:=
authService
.
RegisterOAuthEmailAccount
(
context
.
Background
(),
"fresh@example.com"
,
"secret-123"
,
"246810"
,
""
,
" OIDC "
,
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
tokenPair
)
require
.
NotNil
(
t
,
user
)
require
.
Len
(
t
,
userRepo
.
created
,
1
)
require
.
Equal
(
t
,
"oidc"
,
userRepo
.
created
[
0
]
.
SignupSource
)
}
func
TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail
(
t
*
testing
.
T
)
{
userRepo
:=
&
userRepoStub
{
nextID
:
43
}
emailCache
:=
&
emailCacheStub
{
data
:
&
VerificationCodeData
{
Code
:
"246810"
,
Attempts
:
0
,
CreatedAt
:
time
.
Now
()
.
UTC
(),
ExpiresAt
:
time
.
Now
()
.
UTC
()
.
Add
(
15
*
time
.
Minute
),
},
}
authService
:=
newOAuthEmailFlowAuthService
(
userRepo
,
&
redeemCodeRepoStub
{},
&
refreshTokenCacheStub
{},
map
[
string
]
string
{
SettingKeyRegistrationEnabled
:
"true"
,
SettingKeyEmailVerifyEnabled
:
"true"
,
},
emailCache
,
)
tokenPair
,
user
,
err
:=
authService
.
RegisterOAuthEmailAccount
(
context
.
Background
(),
"fallback@example.com"
,
"secret-123"
,
"246810"
,
""
,
"github"
,
)
require
.
NoError
(
t
,
err
)
require
.
NotNil
(
t
,
tokenPair
)
require
.
NotNil
(
t
,
user
)
require
.
Len
(
t
,
userRepo
.
created
,
1
)
require
.
Equal
(
t
,
"email"
,
userRepo
.
created
[
0
]
.
SignupSource
)
}
func
TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage
(
t
*
testing
.
T
)
{
userRepo
:=
&
userRepoStub
{}
redeemRepo
:=
&
redeemCodeRepoStub
{
...
...
backend/internal/service/auth_pending_identity_service.go
View file @
ddf80f5e
...
...
@@ -5,10 +5,15 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"sort"
"strings"
"sync"
"time"
"entgo.io/ent/dialect"
dbent
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
...
...
@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
entClient
*
dbent
.
Client
}
var
authPendingIdentityScopedKeyLocks
=
newAuthPendingIdentityScopedKeyLockRegistry
()
type
authPendingIdentityScopedKeyLockRegistry
struct
{
mu
sync
.
Mutex
locks
map
[
string
]
*
authPendingIdentityScopedKeyLockEntry
}
type
authPendingIdentityScopedKeyLockEntry
struct
{
mu
sync
.
Mutex
refs
int
}
func
newAuthPendingIdentityScopedKeyLockRegistry
()
*
authPendingIdentityScopedKeyLockRegistry
{
return
&
authPendingIdentityScopedKeyLockRegistry
{
locks
:
make
(
map
[
string
]
*
authPendingIdentityScopedKeyLockEntry
),
}
}
func
(
r
*
authPendingIdentityScopedKeyLockRegistry
)
lock
(
keys
...
string
)
func
()
{
normalized
:=
normalizeAuthPendingIdentityLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
{
return
func
()
{}
}
entries
:=
make
([]
*
authPendingIdentityScopedKeyLockEntry
,
0
,
len
(
normalized
))
r
.
mu
.
Lock
()
for
_
,
key
:=
range
normalized
{
entry
:=
r
.
locks
[
key
]
if
entry
==
nil
{
entry
=
&
authPendingIdentityScopedKeyLockEntry
{}
r
.
locks
[
key
]
=
entry
}
entry
.
refs
++
entries
=
append
(
entries
,
entry
)
}
r
.
mu
.
Unlock
()
for
_
,
entry
:=
range
entries
{
entry
.
mu
.
Lock
()
}
return
func
()
{
for
i
:=
len
(
entries
)
-
1
;
i
>=
0
;
i
--
{
entries
[
i
]
.
mu
.
Unlock
()
}
r
.
mu
.
Lock
()
defer
r
.
mu
.
Unlock
()
for
idx
,
key
:=
range
normalized
{
entry
:=
entries
[
idx
]
entry
.
refs
--
if
entry
.
refs
==
0
{
delete
(
r
.
locks
,
key
)
}
}
}
}
func
normalizeAuthPendingIdentityLockKeys
(
keys
...
string
)
[]
string
{
if
len
(
keys
)
==
0
{
return
nil
}
deduped
:=
make
(
map
[
string
]
struct
{},
len
(
keys
))
for
_
,
key
:=
range
keys
{
trimmed
:=
strings
.
TrimSpace
(
key
)
if
trimmed
==
""
{
continue
}
deduped
[
trimmed
]
=
struct
{}{}
}
if
len
(
deduped
)
==
0
{
return
nil
}
normalized
:=
make
([]
string
,
0
,
len
(
deduped
))
for
key
:=
range
deduped
{
normalized
=
append
(
normalized
,
key
)
}
sort
.
Strings
(
normalized
)
return
normalized
}
func
authPendingIdentityAdvisoryLockHash
(
key
string
)
int64
{
hasher
:=
fnv
.
New64a
()
_
,
_
=
hasher
.
Write
([]
byte
(
key
))
return
int64
(
hasher
.
Sum64
())
}
func
lockAuthPendingIdentityKeys
(
ctx
context
.
Context
,
client
*
dbent
.
Client
,
keys
...
string
)
(
func
(),
error
)
{
release
:=
authPendingIdentityScopedKeyLocks
.
lock
(
keys
...
)
normalized
:=
normalizeAuthPendingIdentityLockKeys
(
keys
...
)
if
len
(
normalized
)
==
0
||
client
==
nil
||
client
.
Driver
()
.
Dialect
()
!=
dialect
.
Postgres
{
return
release
,
nil
}
for
_
,
key
:=
range
normalized
{
var
rows
entsql
.
Rows
if
err
:=
client
.
Driver
()
.
Query
(
ctx
,
"SELECT pg_advisory_xact_lock($1)"
,
[]
any
{
authPendingIdentityAdvisoryLockHash
(
key
)},
&
rows
);
err
!=
nil
{
release
()
return
nil
,
err
}
_
=
rows
.
Close
()
}
return
release
,
nil
}
func
pendingIdentityAdoptionLockKeys
(
pendingAuthSessionID
int64
,
identityID
*
int64
)
[]
string
{
keys
:=
[]
string
{
fmt
.
Sprintf
(
"pending-auth-adoption:pending:%d"
,
pendingAuthSessionID
)}
if
identityID
!=
nil
&&
*
identityID
>
0
{
keys
=
append
(
keys
,
fmt
.
Sprintf
(
"pending-auth-adoption:identity:%d"
,
*
identityID
))
}
return
keys
}
func
NewAuthPendingIdentityService
(
entClient
*
dbent
.
Client
)
*
AuthPendingIdentityService
{
return
&
AuthPendingIdentityService
{
entClient
:
entClient
}
}
...
...
@@ -236,16 +357,66 @@ func (s *AuthPendingIdentityService) consumeSession(
return
nil
,
err
}
sanitizedLocalFlowState
:=
sanitizePendingAuthLocalFlowState
(
session
.
LocalFlowState
)
now
:=
time
.
Now
()
.
UTC
()
updated
,
err
:=
s
.
entClient
.
PendingAuthSession
.
UpdateOneID
(
session
.
ID
)
.
update
:=
s
.
entClient
.
PendingAuthSession
.
UpdateOneID
(
session
.
ID
)
.
Where
(
pendingauthsession
.
ConsumedAtIsNil
(),
pendingauthsession
.
ExpiresAtGTE
(
now
),
pendingauthsession
.
Or
(
pendingauthsession
.
CompletionCodeExpiresAtIsNil
(),
pendingauthsession
.
CompletionCodeExpiresAtGTE
(
now
),
),
)
.
SetConsumedAt
(
now
)
.
SetLocalFlowState
(
sanitizedLocalFlowState
)
.
SetCompletionCodeHash
(
""
)
.
ClearCompletionCodeExpiresAt
()
.
Save
(
ctx
)
if
err
!=
nil
{
ClearCompletionCodeExpiresAt
()
if
expectedBrowserSessionKey
:=
strings
.
TrimSpace
(
session
.
BrowserSessionKey
);
expectedBrowserSessionKey
!=
""
{
update
=
update
.
Where
(
pendingauthsession
.
BrowserSessionKeyEQ
(
expectedBrowserSessionKey
))
}
updated
,
err
:=
update
.
Save
(
ctx
)
if
err
==
nil
{
return
updated
,
nil
}
if
!
dbent
.
IsNotFound
(
err
)
{
return
nil
,
err
}
current
,
currentErr
:=
s
.
entClient
.
PendingAuthSession
.
Get
(
ctx
,
session
.
ID
)
if
currentErr
!=
nil
{
if
dbent
.
IsNotFound
(
currentErr
)
{
return
nil
,
ErrPendingAuthSessionNotFound
}
return
nil
,
currentErr
}
if
err
:=
validatePendingSessionState
(
current
,
browserSessionKey
,
expiredErr
,
consumedErr
);
err
!=
nil
{
return
nil
,
err
}
return
updated
,
nil
return
nil
,
consumedErr
}
func
sanitizePendingAuthLocalFlowState
(
localFlowState
map
[
string
]
any
)
map
[
string
]
any
{
sanitized
:=
copyPendingMap
(
localFlowState
)
if
len
(
sanitized
)
==
0
{
return
sanitized
}
rawCompletion
,
ok
:=
sanitized
[
"completion_response"
]
if
!
ok
{
return
sanitized
}
completion
,
ok
:=
rawCompletion
.
(
map
[
string
]
any
)
if
!
ok
{
return
sanitized
}
cleanedCompletion
:=
copyPendingMap
(
completion
)
for
_
,
key
:=
range
[]
string
{
"access_token"
,
"refresh_token"
,
"expires_in"
,
"token_type"
}
{
delete
(
cleanedCompletion
,
key
)
}
sanitized
[
"completion_response"
]
=
cleanedCompletion
return
sanitized
}
func
validatePendingSessionState
(
session
*
dbent
.
PendingAuthSession
,
browserSessionKey
string
,
expiredErr
error
,
consumedErr
error
)
error
{
...
...
@@ -274,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
return
nil
,
fmt
.
Errorf
(
"pending auth ent client is not configured"
)
}
tx
,
err
:=
s
.
entClient
.
Tx
(
ctx
)
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
dbent
.
ErrTxStarted
)
{
return
nil
,
err
}
client
:=
s
.
entClient
txCtx
:=
ctx
if
err
==
nil
{
defer
func
()
{
_
=
tx
.
Rollback
()
}()
client
=
tx
.
Client
()
txCtx
=
dbent
.
NewTxContext
(
ctx
,
tx
)
}
else
if
existingTx
:=
dbent
.
TxFromContext
(
ctx
);
existingTx
!=
nil
{
client
=
existingTx
.
Client
()
}
releaseLocks
,
err
:=
lockAuthPendingIdentityKeys
(
txCtx
,
client
,
pendingIdentityAdoptionLockKeys
(
input
.
PendingAuthSessionID
,
input
.
IdentityID
)
...
)
if
err
!=
nil
{
return
nil
,
err
}
defer
releaseLocks
()
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
if
_
,
err
:=
s
.
entC
lient
.
IdentityAdoptionDecision
.
Update
()
.
if
_
,
err
:=
c
lient
.
IdentityAdoptionDecision
.
Update
()
.
Where
(
identityadoptiondecision
.
IdentityIDEQ
(
*
input
.
IdentityID
),
dbpredicate
.
IdentityAdoptionDecision
(
func
(
s
*
entsql
.
Selector
)
{
...
...
@@ -287,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
}),
)
.
ClearIdentityID
()
.
Save
(
c
tx
);
err
!=
nil
{
Save
(
txC
tx
);
err
!=
nil
{
return
nil
,
err
}
}
existing
,
err
:=
s
.
entClient
.
IdentityAdoptionDecision
.
Query
()
.
Where
(
identityadoptiondecision
.
PendingAuthSessionIDEQ
(
input
.
PendingAuthSessionID
))
.
Only
(
ctx
)
if
err
!=
nil
&&
!
dbent
.
IsNotFound
(
err
)
{
create
:=
client
.
IdentityAdoptionDecision
.
Create
()
.
SetPendingAuthSessionID
(
input
.
PendingAuthSessionID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
.
SetDecidedAt
(
time
.
Now
()
.
UTC
())
if
input
.
IdentityID
!=
nil
&&
*
input
.
IdentityID
>
0
{
create
=
create
.
SetIdentityID
(
*
input
.
IdentityID
)
}
decisionID
,
err
:=
create
.
OnConflictColumns
(
identityadoptiondecision
.
FieldPendingAuthSessionID
)
.
UpdateNewValues
()
.
ID
(
txCtx
)
if
err
!=
nil
{
return
nil
,
err
}
if
existing
==
nil
{
create
:=
s
.
entClient
.
IdentityAdoptionDecision
.
Create
()
.
SetPendingAuthSessionID
(
input
.
PendingAuthSessionID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
.
SetDecidedAt
(
time
.
Now
()
.
UTC
())
if
input
.
IdentityID
!=
nil
{
create
=
create
.
SetIdentityID
(
*
input
.
IdentityID
)
}
return
create
.
Save
(
ctx
)
decision
,
err
:=
client
.
IdentityAdoptionDecision
.
Get
(
txCtx
,
decisionID
)
if
err
!=
nil
{
return
nil
,
err
}
update
:=
s
.
entClient
.
IdentityAdoptionDecision
.
UpdateOneID
(
existing
.
ID
)
.
SetAdoptDisplayName
(
input
.
AdoptDisplayName
)
.
SetAdoptAvatar
(
input
.
AdoptAvatar
)
if
input
.
IdentityID
!=
nil
{
update
=
update
.
SetIdentityID
(
*
input
.
IdentityID
)
if
tx
!=
nil
{
if
err
:=
tx
.
Commit
();
err
!=
nil
{
return
nil
,
err
}
}
return
update
.
Save
(
ctx
)
return
decision
,
nil
}
func
copyPendingMap
(
in
map
[
string
]
any
)
map
[
string
]
any
{
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment