Unverified Commit ddf80f5e authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1799 from IanShaw027/rebuild/auth-identity-foundation

fix(auth,payment,profile): 修复认证身份和支付系统的后续问题
parents 4d0483f5 c048ca80
......@@ -94,6 +94,24 @@ func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
func TestMigrationChecksumCompatibilityRules_CoverEditedUpgradeCompatibilityMigrations(t *testing.T) {
for _, name := range []string{
"109_auth_identity_compat_backfill.sql",
"110_pending_auth_and_provider_default_grants.sql",
"112_add_payment_order_provider_key_snapshot.sql",
"115_auth_identity_legacy_external_backfill.sql",
"116_auth_identity_legacy_external_safety_reports.sql",
"118_wechat_dual_mode_and_auth_source_defaults.sql",
"120_enforce_payment_orders_out_trade_no_unique_notx.sql",
"123_fix_legacy_auth_source_grant_on_signup_defaults.sql",
} {
rule, ok := migrationChecksumCompatibilityRules[name]
require.Truef(t, ok, "missing compatibility rule for %s", name)
require.NotEmpty(t, rule.fileChecksum)
require.NotEmpty(t, rule.acceptedDBChecksum)
}
}
func TestEnsureAtlasBaselineAligned(t *testing.T) {
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
db, mock, err := sqlmock.New()
......
......@@ -116,6 +116,84 @@ CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_FailsFastOnDuplicatePrecheck(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}).AddRow("dup-out-trade-no", 2))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.Error(t, err)
require.Contains(t, err.Error(), "duplicate out_trade_no")
require.Contains(t, err.Error(), "dup-out-trade-no")
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_PaymentOrdersOutTradeNoUniqueMigration_DropsInvalidIndexBeforeRetry(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer func() { _ = db.Close() }()
prepareMigrationsBootstrapExpectations(mock)
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql").
WillReturnError(sql.ErrNoRows)
mock.ExpectQuery("SELECT out_trade_no, COUNT\\(\\*\\) AS duplicate_count FROM payment_orders").
WillReturnRows(sqlmock.NewRows([]string{"out_trade_no", "duplicate_count"}))
mock.ExpectQuery("SELECT EXISTS \\(").
WithArgs("paymentorder_out_trade_no_unique").
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no").
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
WithArgs("120_enforce_payment_orders_out_trade_no_unique_notx.sql", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
WithArgs(migrationsAdvisoryLockID).
WillReturnResult(sqlmock.NewResult(0, 1))
fsys := fstest.MapFS{
"120_enforce_payment_orders_out_trade_no_unique_notx.sql": &fstest.MapFile{
Data: []byte(`
CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS paymentorder_out_trade_no_unique
ON payment_orders (out_trade_no)
WHERE out_trade_no <> '';
DROP INDEX CONCURRENTLY IF EXISTS paymentorder_out_trade_no;
`),
},
}
err = applyMigrationsFS(context.Background(), db, fsys)
require.NoError(t, err)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
......
......@@ -89,6 +89,35 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func TestMigrationsRunner_AuthIdentityAndPaymentSchemaStayAligned(t *testing.T) {
tx := testTx(t)
requireColumn(t, tx, "auth_identity_migration_reports", "report_type", "character varying", 80, false)
requireColumn(t, tx, "users", "signup_source", "character varying", 20, false)
requireColumnDefaultContains(t, tx, "users", "signup_source", "email")
requireConstraintDefinitionContains(
t,
tx,
"users",
"users_signup_source_check",
"signup_source",
"'email'",
"'linuxdo'",
"'wechat'",
"'oidc'",
)
requireForeignKeyOnDelete(t, tx, "auth_identities", "user_id", "users", "CASCADE")
requireForeignKeyOnDelete(t, tx, "auth_identity_channels", "identity_id", "auth_identities", "CASCADE")
requireForeignKeyOnDelete(t, tx, "pending_auth_sessions", "target_user_id", "users", "SET NULL")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "pending_auth_session_id", "pending_auth_sessions", "CASCADE")
requireForeignKeyOnDelete(t, tx, "identity_adoption_decisions", "identity_id", "auth_identities", "SET NULL")
requireIndex(t, tx, "payment_orders", "paymentorder_out_trade_no")
requirePartialUniqueIndexDefinition(t, tx, "payment_orders", "paymentorder_out_trade_no", "out_trade_no", "WHERE")
requireIndexAbsent(t, tx, "payment_orders", "paymentorder_out_trade_no_unique")
}
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
......@@ -106,6 +135,118 @@ SELECT EXISTS (
require.True(t, exists, "expected index %s on %s", index, table)
}
func requireIndexAbsent(t *testing.T, tx *sql.Tx, table, index string) {
t.Helper()
var exists bool
err := tx.QueryRowContext(context.Background(), `
SELECT EXISTS (
SELECT 1
FROM pg_indexes
WHERE schemaname = 'public'
AND tablename = $1
AND indexname = $2
)
`, table, index).Scan(&exists)
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
require.False(t, exists, "expected index %s on %s to be absent", index, table)
}
func requirePartialUniqueIndexDefinition(t *testing.T, tx *sql.Tx, table, index string, fragments ...string) {
t.Helper()
var (
unique bool
def string
)
err := tx.QueryRowContext(context.Background(), `
SELECT
i.indisunique,
pg_get_indexdef(i.indexrelid)
FROM pg_class idx
JOIN pg_index i ON i.indexrelid = idx.oid
JOIN pg_class tbl ON tbl.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND idx.relname = $2
`, table, index).Scan(&unique, &def)
require.NoError(t, err, "query index definition for %s.%s", table, index)
require.True(t, unique, "expected index %s on %s to be unique", index, table)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected index definition for %s.%s to contain %q", table, index, fragment)
}
}
func requireForeignKeyOnDelete(t *testing.T, tx *sql.Tx, table, column, refTable, expected string) {
t.Helper()
var actual string
err := tx.QueryRowContext(context.Background(), `
SELECT CASE c.confdeltype
WHEN 'a' THEN 'NO ACTION'
WHEN 'r' THEN 'RESTRICT'
WHEN 'c' THEN 'CASCADE'
WHEN 'n' THEN 'SET NULL'
WHEN 'd' THEN 'SET DEFAULT'
END
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
JOIN pg_class ref_tbl ON ref_tbl.oid = c.confrelid
JOIN pg_attribute attr ON attr.attrelid = tbl.oid AND attr.attnum = ANY(c.conkey)
WHERE ns.nspname = 'public'
AND c.contype = 'f'
AND tbl.relname = $1
AND attr.attname = $2
AND ref_tbl.relname = $3
LIMIT 1
`, table, column, refTable).Scan(&actual)
require.NoError(t, err, "query foreign key action for %s.%s -> %s", table, column, refTable)
require.Equal(t, expected, actual, "unexpected ON DELETE action for %s.%s -> %s", table, column, refTable)
}
func requireConstraintDefinitionContains(t *testing.T, tx *sql.Tx, table, constraint string, fragments ...string) {
t.Helper()
var def string
err := tx.QueryRowContext(context.Background(), `
SELECT pg_get_constraintdef(c.oid)
FROM pg_constraint c
JOIN pg_class tbl ON tbl.oid = c.conrelid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE ns.nspname = 'public'
AND tbl.relname = $1
AND c.conname = $2
`, table, constraint).Scan(&def)
require.NoError(t, err, "query constraint definition for %s.%s", table, constraint)
for _, fragment := range fragments {
require.Contains(t, def, fragment, "expected constraint definition for %s.%s to contain %q", table, constraint, fragment)
}
}
func requireColumnDefaultContains(t *testing.T, tx *sql.Tx, table, column string, fragments ...string) {
t.Helper()
var columnDefault sql.NullString
err := tx.QueryRowContext(context.Background(), `
SELECT column_default
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`, table, column).Scan(&columnDefault)
require.NoError(t, err, "query column_default for %s.%s", table, column)
require.True(t, columnDefault.Valid, "expected column_default for %s.%s", table, column)
for _, fragment := range fragments {
require.Contains(t, columnDefault.String, fragment, "expected default for %s.%s to contain %q", table, column, fragment)
}
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
......
......@@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"fmt"
"hash/fnv"
"reflect"
"sort"
"strings"
"sync"
"time"
"unsafe"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
......@@ -120,6 +124,113 @@ type sqlQueryExecutor interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
var repositoryScopedKeyLocks = newScopedKeyLockRegistry()
type scopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*scopedKeyLockEntry
}
type scopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newScopedKeyLockRegistry() *scopedKeyLockRegistry {
return &scopedKeyLockRegistry{
locks: make(map[string]*scopedKeyLockEntry),
}
}
func (r *scopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*scopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &scopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func advisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockRepositoryScopedKeys(ctx context.Context, client *dbent.Client, exec sqlQueryExecutor, keys ...string) (func(), error) {
release := repositoryScopedKeyLocks.lock(keys...)
normalized := normalizeLockKeys(keys...)
if len(normalized) == 0 || client == nil || exec == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
rows, err := exec.QueryContext(ctx, "SELECT pg_advisory_xact_lock($1)", advisoryLockHash(key))
if err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil {
return fn(ctx)
......@@ -301,17 +412,18 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
identityRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderKeyIn(compatibleIdentityProviderKeys(canonical.ProviderType, canonical.ProviderKey)...),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if identity != nil && identity.UserID != input.UserID {
identity := selectOwnedCompatibleIdentity(identityRecords, input.UserID)
if identity == nil && hasCompatibleIdentityConflict(identityRecords, input.UserID) {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
......@@ -328,7 +440,11 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err
}
} else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(canonical.ProviderType, identity.ProviderKey, canonical.ProviderKey)
update := client.AuthIdentity.UpdateOneID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, identity.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata))
}
......@@ -346,20 +462,21 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
channelRecords, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ProviderKeyIn(compatibleIdentityProviderKeys(input.Channel.ProviderType, input.Channel.ProviderKey)...),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
All(txCtx)
if err != nil {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
channel = selectOwnedCompatibleChannel(channelRecords, input.UserID)
if channel == nil && hasCompatibleChannelConflict(channelRecords, input.UserID) {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
......@@ -376,8 +493,12 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return err
}
} else {
targetProviderKey := canonicalizeCompatibleIdentityProviderKey(input.Channel.ProviderType, channel.ProviderKey, input.Channel.ProviderKey)
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID)
if targetProviderKey != "" && !strings.EqualFold(targetProviderKey, channel.ProviderKey) {
update = update.SetProviderKey(targetProviderKey)
}
if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
}
......@@ -397,6 +518,104 @@ func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindA
return result, nil
}
func compatibleIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalizeCompatibleIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func compatibleIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedCompatibleIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasCompatibleIdentityConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedCompatibleChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || compatibleIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < compatibleIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasCompatibleChannelConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
......@@ -422,51 +641,70 @@ ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
}
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client)
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(ctx); err != nil {
return nil, err
var result *dbent.IdentityAdoptionDecision
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
releaseLocks, err := lockRepositoryScopedKeys(
txCtx,
client,
txAwareSQLExecutor(txCtx, r.sql, r.client),
identityAdoptionDecisionLockKeys(input.PendingAuthSessionID, input.IdentityID)...,
)
if err != nil {
return err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(txCtx); err != nil {
return err
}
}
}
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now)
if input.IdentityID != nil {
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return err
}
result, err = client.IdentityAdoptionDecision.Get(txCtx, decisionID)
return err
})
if err != nil {
return nil, err
}
return result, nil
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
func identityAdoptionDecisionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("identity-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("identity-adoption:identity:%d", *identityID))
}
return update.Save(ctx)
return keys
}
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
......
......@@ -10,6 +10,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
......@@ -186,6 +188,79 @@ func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAn
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_ReusesLegacyWeChatAliasRecords() {
user := s.mustCreateUser("wechat-legacy-alias")
legacyIdentity, err := s.client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy-alias"}).
Save(s.ctx)
s.Require().NoError(err)
legacyChannel, err := s.client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("oa").
SetChannelAppID("wx-app-legacy").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy-alias"}).
Save(s.ctx)
s.Require().NoError(err)
bound, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
Channel: "oa",
ChannelAppID: "wx-app-legacy",
ChannelSubject: "openid-legacy-123",
},
Metadata: map[string]any{"source": "canonical-bind"},
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
})
s.Require().NoError(err)
s.Require().NotNil(bound)
s.Require().NotNil(bound.Identity)
s.Require().NotNil(bound.Channel)
s.Require().Equal(legacyIdentity.ID, bound.Identity.ID)
s.Require().Equal(legacyChannel.ID, bound.Channel.ID)
s.Require().Equal("wechat-main", bound.Identity.ProviderKey)
s.Require().Equal("wechat-main", bound.Channel.ProviderKey)
s.Require().Equal("canonical-bind", bound.Identity.Metadata["source"])
s.Require().Equal("canonical-bind", bound.Channel.Metadata["scene"])
identityCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, identityCount)
channelCount, err := s.client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("oa"),
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, channelCount)
}
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-create")
......
package repository
import (
"context"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestUserRepositoryBindAuthIdentityToUserCanonicalizesLegacyWeChatAlias(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "wechat-legacy@example.com",
Username: "wechat-legacy",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
legacyChannel, err := client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("oa").
SetChannelAppID("wx-app-legacy").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy-alias"}).
Save(ctx)
require.NoError(t, err)
bound, err := repo.BindAuthIdentityToUser(ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
Channel: "oa",
ChannelAppID: "wx-app-legacy",
ChannelSubject: "openid-legacy-123",
},
Metadata: map[string]any{"source": "canonical-bind"},
ChannelMetadata: map[string]any{"scene": "canonical-bind"},
})
require.NoError(t, err)
require.NotNil(t, bound)
require.NotNil(t, bound.Identity)
require.NotNil(t, bound.Channel)
require.Equal(t, legacyIdentity.ID, bound.Identity.ID)
require.Equal(t, legacyChannel.ID, bound.Channel.ID)
require.Equal(t, "wechat-main", bound.Identity.ProviderKey)
require.Equal(t, "wechat-main", bound.Channel.ProviderKey)
reloadedIdentity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedIdentity.ProviderKey)
require.Equal(t, "canonical-bind", reloadedIdentity.Metadata["source"])
reloadedChannel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", reloadedChannel.ProviderKey)
require.Equal(t, "canonical-bind", reloadedChannel.Metadata["scene"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("oa"),
authidentitychannel.ChannelAppIDEQ("wx-app-legacy"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, channelCount)
}
func TestUserRepositoryUpsertIdentityAdoptionDecisionIsIdempotentUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
user := &service.User{
Email: "repo-adoption@example.com",
Username: "repo-adoption",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, user))
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("pending-repo-adoption").
SetIntent("bind_current_user").
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("union-repo-adoption").
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": "union-repo-adoption"}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(ctx)
require.NoError(t, err)
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.IdentityAdoptionDecision.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type adoptionResult struct {
decision *dbent.IdentityAdoptionDecision
err error
}
input := IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
}
results := make(chan adoptionResult, 2)
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
<-firstCreateStarted
go func() {
decision, err := repo.UpsertIdentityAdoptionDecision(ctx, input)
results <- adoptionResult{decision: decision, err: err}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
require.NoError(t, first.err)
require.NoError(t, second.err)
require.NotNil(t, first.decision)
require.NotNil(t, second.decision)
require.Equal(t, first.decision.ID, second.decision.ID)
count, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
loaded, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, loaded.IdentityID)
require.Equal(t, identity.ID, *loaded.IdentityID)
require.True(t, loaded.AdoptDisplayName)
require.True(t, loaded.AdoptAvatar)
}
......@@ -52,9 +52,11 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
......@@ -64,6 +66,21 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
}
}
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, 0, userIn.Email); err != nil {
return err
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
......@@ -76,15 +93,15 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
Save(txCtx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
if err := ensureEmailAuthIdentityWithClient(txCtx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
......@@ -154,9 +171,11 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
}
var txClient *dbent.Client
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else {
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
......@@ -165,7 +184,23 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client
}
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
releaseEmailLock, err := lockRepositoryScopedKeys(
txCtx,
txClient,
txAwareSQLExecutor(txCtx, r.sql, r.client),
normalizedEmailUniquenessLockKey(userIn.Email),
)
if err != nil {
return err
}
defer releaseEmailLock()
if err := ensureNormalizedEmailAvailableWithClient(txCtx, txClient, userIn.ID, userIn.Email); err != nil {
return err
}
existing, err := clientFromContext(txCtx, txClient).User.Get(txCtx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
......@@ -197,15 +232,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(ctx)
updated, err := updateOp.Save(txCtx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
if err := r.syncUserAllowedGroupsWithClient(txCtx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
if err := replaceEmailAuthIdentityWithClient(txCtx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
......@@ -704,8 +739,28 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}
func ensureNormalizedEmailAvailableWithClient(ctx context.Context, client *dbent.Client, userID int64, email string) error {
client = clientFromContext(ctx, client)
if client == nil {
return nil
}
matches, err := client.User.Query().
Where(userEmailLookupPredicate(email)).
All(ctx)
if err != nil {
return err
}
for _, match := range matches {
if match.ID != userID {
return service.ErrEmailExists
}
}
return nil
}
func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email))
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return dbuser.EmailEQ(email)
}
......@@ -719,6 +774,18 @@ func userEmailLookupPredicate(email string) predicate.User {
})
}
func normalizeEmailLookupValue(email string) string {
return strings.ToLower(strings.TrimSpace(email))
}
func normalizedEmailUniquenessLockKey(email string) string {
normalized := normalizeEmailLookupValue(email)
if normalized == "" {
return ""
}
return "users:normalized-email:" + normalized
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
err := client.UserAllowedGroup.Create().
......@@ -853,11 +920,14 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
}
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
switch strings.TrimSpace(strings.ToLower(signupSource)) {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
......
......@@ -3,7 +3,10 @@ package repository
import (
"context"
"database/sql"
"fmt"
"sync"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
......@@ -18,9 +21,10 @@ import (
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
db, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", t.Name()))
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
db.SetMaxOpenConns(10)
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
......@@ -67,3 +71,157 @@ func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T)
require.NoError(t, err)
require.True(t, exists)
}
func TestUserRepositoryCreateRejectsNormalizedEmailDuplicate(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
err := repo.Create(ctx, &service.User{
Email: " Existing@Example.com ",
Username: "existing-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.NoError(t, err)
err = repo.Create(ctx, &service.User{
Email: "existing@example.com",
Username: "duplicate-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.ErrorIs(t, err, service.ErrEmailExists)
}
func TestUserRepositoryUpdateRejectsNormalizedEmailDuplicate(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
first := &service.User{
Email: " Existing@Example.com ",
Username: "existing-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, first))
second := &service.User{
Email: "second@example.com",
Username: "second-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, second))
second.Email = " existing@example.com "
err := repo.Update(ctx, second)
require.ErrorIs(t, err, service.ErrEmailExists)
}
func TestUserRepositoryGetByEmailReportsNormalizedEmailConflict(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
_, err := client.User.Create().
SetEmail("Conflict@Example.com").
SetUsername("conflict-user-1").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.User.Create().
SetEmail(" conflict@example.com ").
SetUsername("conflict-user-2").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = repo.GetByEmail(ctx, "conflict@example.com")
require.Error(t, err)
require.ErrorContains(t, err, "normalized email lookup matched multiple users")
}
func TestUserRepositoryCreateSerializesNormalizedEmailConflictsUnderConcurrency(t *testing.T) {
repo, client := newUserEntRepo(t)
ctx := context.Background()
firstCreateStarted := make(chan struct{})
releaseFirstCreate := make(chan struct{})
var firstCreate sync.Once
client.User.Use(func(next dbent.Mutator) dbent.Mutator {
return dbent.MutateFunc(func(ctx context.Context, m dbent.Mutation) (dbent.Value, error) {
blocked := false
if m.Op().Is(dbent.OpCreate) {
firstCreate.Do(func() {
blocked = true
close(firstCreateStarted)
})
}
if blocked {
<-releaseFirstCreate
}
return next.Mutate(ctx, m)
})
})
type createResult struct {
err error
}
results := make(chan createResult, 2)
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: " Race@Example.com ",
Username: "race-user-1",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
<-firstCreateStarted
go func() {
results <- createResult{err: repo.Create(ctx, &service.User{
Email: "race@example.com",
Username: "race-user-2",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})}
}()
time.Sleep(100 * time.Millisecond)
close(releaseFirstCreate)
first := <-results
second := <-results
errors := []error{first.err, second.err}
successes := 0
conflicts := 0
for _, err := range errors {
switch err {
case nil:
successes++
case service.ErrEmailExists:
conflicts++
default:
t.Fatalf("unexpected create error: %v", err)
}
}
require.Equal(t, 1, successes)
require.Equal(t, 1, conflicts)
count, err := client.User.Query().Where(userEmailLookupPredicate("race@example.com")).Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, count)
}
......@@ -85,7 +85,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
......@@ -93,7 +93,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
......@@ -101,7 +101,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
......@@ -122,7 +122,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
......@@ -130,7 +130,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
......@@ -138,7 +138,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
......@@ -159,7 +159,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
......@@ -167,7 +167,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/oidc/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
......@@ -175,7 +175,7 @@ func TestAPIContracts(t *testing.T) {
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
......@@ -784,6 +784,198 @@ func TestAPIContracts(t *testing.T) {
}
}`,
},
{
name: "GET /api/v1/admin/settings falls back to config oauth defaults",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
deps.cfg.OIDC = config.OIDCConnectConfig{
Enabled: true,
ProviderName: "ConfigOIDC",
ClientID: "oidc-config-client",
ClientSecret: "oidc-config-secret",
IssuerURL: "https://issuer.example.com",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
Scopes: "openid email profile",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256,ES256,PS256",
ClockSkewSeconds: 120,
}
deps.cfg.WeChat = config.WeChatConnectConfig{
Enabled: true,
OpenEnabled: true,
OpenAppID: "wx-open-config",
OpenAppSecret: "wx-open-secret",
Mode: "open",
Scopes: "snsapi_login",
FrontendRedirectURL: "/auth/wechat/callback",
}
deps.settingRepo.SetAll(map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyEmailVerifyEnabled: "false",
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
})
},
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"registration_enabled": true,
"email_verify_enabled": false,
"registration_email_suffix_whitelist": [],
"promo_code_enabled": true,
"password_reset_enabled": false,
"frontend_url": "",
"invitation_code_enabled": false,
"totp_enabled": false,
"totp_encryption_key_configured": false,
"smtp_host": "",
"smtp_port": 587,
"smtp_username": "",
"smtp_password_configured": false,
"smtp_from_email": "",
"smtp_from_name": "",
"smtp_use_tls": false,
"turnstile_enabled": false,
"turnstile_site_key": "",
"turnstile_secret_key_configured": false,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"oidc_connect_enabled": true,
"oidc_connect_provider_name": "ConfigOIDC",
"oidc_connect_client_id": "oidc-config-client",
"oidc_connect_client_secret_configured": true,
"oidc_connect_issuer_url": "https://issuer.example.com",
"oidc_connect_discovery_url": "",
"oidc_connect_authorize_url": "",
"oidc_connect_token_url": "",
"oidc_connect_userinfo_url": "",
"oidc_connect_jwks_url": "",
"oidc_connect_scopes": "openid email profile",
"oidc_connect_redirect_url": "https://api.example.com/api/v1/auth/oauth/oidc/callback",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
"oidc_connect_require_email_verified": false,
"oidc_connect_userinfo_email_path": "",
"oidc_connect_userinfo_id_path": "",
"oidc_connect_userinfo_username_path": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
"api_base_url": "",
"contact_info": "",
"doc_url": "",
"home_content": "",
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
"table_default_page_size": 20,
"table_page_size_options": [10, 20, 50],
"custom_menu_items": [],
"custom_endpoints": [],
"default_concurrency": 0,
"default_balance": 0,
"default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_openai": "gpt-4o",
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_antigravity": "gemini-2.5-pro",
"enable_identity_patch": true,
"identity_patch_prompt": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60,
"min_claude_code_version": "",
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"enable_cch_signing": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "",
"payment_visible_method_wxpay_source": "",
"payment_visible_method_alipay_enabled": false,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"payment_enabled": false,
"payment_min_amount": 0,
"payment_max_amount": 0,
"payment_daily_limit": 0,
"payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false,
"payment_balance_recharge_multiplier": 0,
"payment_recharge_fee_rate": 0,
"payment_load_balance_strategy": "",
"payment_product_name_prefix": "",
"payment_product_name_suffix": "",
"payment_help_image_url": "",
"payment_help_text": "",
"payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
"wechat_connect_enabled": true,
"wechat_connect_app_id": "wx-open-config",
"wechat_connect_app_secret_configured": true,
"wechat_connect_mode": "open",
"wechat_connect_open_enabled": true,
"wechat_connect_open_app_id": "wx-open-config",
"wechat_connect_open_app_secret_configured": true,
"wechat_connect_mp_enabled": false,
"wechat_connect_mp_app_id": "wx-open-config",
"wechat_connect_mp_app_secret_configured": true,
"wechat_connect_mobile_enabled": false,
"wechat_connect_mobile_app_id": "wx-open-config",
"wechat_connect_mobile_app_secret_configured": true,
"wechat_connect_redirect_url": "",
"wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
"wechat_connect_scopes": "snsapi_login",
"auth_source_default_email_balance": 0,
"auth_source_default_email_concurrency": 5,
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
"auth_source_default_linuxdo_grant_on_signup": false,
"auth_source_default_linuxdo_grant_on_first_bind": false,
"auth_source_default_oidc_balance": 0,
"auth_source_default_oidc_concurrency": 5,
"auth_source_default_oidc_subscriptions": [],
"auth_source_default_oidc_grant_on_signup": false,
"auth_source_default_oidc_grant_on_first_bind": false,
"auth_source_default_wechat_balance": 0,
"auth_source_default_wechat_concurrency": 5,
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"force_email_on_third_party_signup": false
}
}`,
},
{
name: "POST /api/v1/admin/accounts/bulk-update",
method: http.MethodPost,
......@@ -827,6 +1019,7 @@ func TestAPIContracts(t *testing.T) {
type contractDeps struct {
now time.Time
router http.Handler
cfg *config.Config
apiKeyRepo *stubApiKeyRepo
groupRepo *stubGroupRepo
userSubRepo *stubUserSubscriptionRepo
......@@ -947,6 +1140,7 @@ func newContractDeps(t *testing.T) *contractDeps {
return &contractDeps{
now: now,
router: r,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
......
......@@ -27,23 +27,50 @@ func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFun
}
}
func backendModeAllowsAuthPath(path string) bool {
path = strings.ToLower(strings.TrimSpace(path))
for _, suffix := range []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"} {
if strings.HasSuffix(path, suffix) {
return true
}
}
for _, suffix := range []string{
"/auth/oauth/linuxdo/callback",
"/auth/oauth/wechat/callback",
"/auth/oauth/wechat/payment/callback",
"/auth/oauth/oidc/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
"/auth/oauth/linuxdo/create-account",
"/auth/oauth/wechat/create-account",
"/auth/oauth/oidc/create-account",
"/auth/oauth/linuxdo/bind-login",
"/auth/oauth/wechat/bind-login",
"/auth/oauth/oidc/bind-login",
} {
if strings.HasSuffix(path, suffix) {
return true
}
}
return strings.Contains(path, "/auth/oauth/pending/")
}
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
// Allows: login, login/2fa, logout, refresh (admin needs these).
// Blocks: register, forgot-password, reset-password, OAuth, etc.
// Allows the minimal auth surface admins still need in backend mode, including
// OAuth callbacks and pending continuations. Handler-level backend mode checks
// still enforce admin-only login and forbid self-service registration.
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
return func(c *gin.Context) {
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
c.Next()
return
}
path := c.Request.URL.Path
// Allow login, 2FA, logout, refresh, public settings
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
for _, suffix := range allowedSuffixes {
if strings.HasSuffix(path, suffix) {
c.Next()
return
}
if backendModeAllowsAuthPath(c.Request.URL.Path) {
c.Next()
return
}
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
c.Abort()
......
......@@ -198,6 +198,96 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/refresh",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_linuxdo_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_linuxdo_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_wechat_payment_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_wechat_payment_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/payment/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_oidc_oauth_start",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/start",
wantStatus: http.StatusForbidden,
},
{
name: "enabled_allows_oidc_oauth_callback",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/callback",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",
path: "/api/v1/auth/oauth/pending/exchange",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_send_verify_code",
enabled: "true",
path: "/api/v1/auth/oauth/pending/send-verify-code",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/pending/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_oauth_pending_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/pending/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_bind_login",
enabled: "true",
path: "/api/v1/auth/oauth/oidc/bind-login",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_provider_create_account",
enabled: "true",
path: "/api/v1/auth/oauth/wechat/create-account",
wantStatus: http.StatusOK,
},
{
name: "enabled_allows_legacy_complete_registration",
enabled: "true",
path: "/api/v1/auth/oauth/linuxdo/complete-registration",
wantStatus: http.StatusOK,
},
{
name: "enabled_blocks_register",
enabled: "true",
......
......@@ -63,8 +63,20 @@ func RegisterAuthRoutes(
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
......@@ -129,6 +141,12 @@ func RegisterAuthRoutes(
h.Auth.CreateWeChatOAuthAccount,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
rateLimiter.LimitWithOptions("oauth-oidc-complete", 10, time.Minute, middleware.RateLimitOptions{
......@@ -164,23 +182,6 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
authenticated.POST("/auth/oauth/bind-token", h.Auth.PrepareOAuthBindAccessTokenCookie)
}
}
......@@ -44,9 +44,9 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
// Signed resume-token recovery is the supported public lookup path.
// The legacy anonymous out_trade_no verify endpoint is kept only as a
// compatibility shim that returns HTTP 410 Gone.
// Signed resume-token recovery is the preferred public lookup path.
// The legacy anonymous out_trade_no verify endpoint remains available as a
// persisted-state compatibility path for staggered upgrades.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
......
......@@ -419,6 +419,7 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string) error {
ctx := c.Request.Context()
_ = prompt
// Default to openai.DefaultTestModel for OpenAI testing
testModelID := modelID
......
......@@ -879,6 +879,8 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
}
canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey)
compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey)
var issuer *string
if input.Issuer != nil {
......@@ -900,25 +902,26 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
}
defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query().
identityRecords, err := tx.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderKeyIn(compatibleProviderKeys...),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil && identity.UserID != userID {
if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
identity := selectOwnedAdminAuthIdentity(identityRecords, userID)
if identity == nil {
create := tx.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt)
if issuer != nil {
......@@ -932,7 +935,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
} else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
update := tx.AuthIdentity.UpdateOneID(identity.ID).
SetVerifiedAt(verifiedAt).
SetProviderKey(canonicalProviderKey)
if issuer != nil {
update = update.SetIssuer(*issuer)
}
......@@ -947,27 +952,28 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
var channel *dbent.AuthIdentityChannel
if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query().
channelRecords, err := tx.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey),
authidentitychannel.ProviderKeyIn(compatibleProviderKeys...),
authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
).
WithIdentity().
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID)
if channel == nil {
create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderKey(canonicalProviderKey).
SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject)
......@@ -979,7 +985,9 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
} else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID).
SetProviderKey(canonicalProviderKey)
if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
......@@ -996,6 +1004,105 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
return buildAdminBoundAuthIdentity(identity, channel), nil
}
func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerKey == "" {
return []string{providerKey}
}
if providerType != "wechat" {
return []string{providerKey}
}
keys := []string{providerKey}
if !strings.EqualFold(providerKey, "wechat-main") {
keys = append(keys, "wechat-main")
}
if !strings.EqualFold(providerKey, "wechat") {
keys = append(keys, "wechat")
}
return keys
}
func canonicalAdminAuthIdentityProviderKey(providerType, existingKey, requestedKey string) string {
providerType = strings.TrimSpace(strings.ToLower(providerType))
existingKey = strings.TrimSpace(existingKey)
requestedKey = strings.TrimSpace(requestedKey)
if providerType != "wechat" {
if requestedKey != "" {
return requestedKey
}
return existingKey
}
if strings.EqualFold(existingKey, "wechat") || strings.EqualFold(existingKey, "wechat-main") || strings.EqualFold(requestedKey, "wechat-main") {
return "wechat-main"
}
if requestedKey != "" {
return requestedKey
}
return existingKey
}
func adminAuthIdentityProviderKeyRank(providerType, providerKey string) int {
providerType = strings.TrimSpace(strings.ToLower(providerType))
providerKey = strings.TrimSpace(providerKey)
if providerType != "wechat" {
return 0
}
switch {
case strings.EqualFold(providerKey, "wechat-main"):
return 0
case strings.EqualFold(providerKey, "wechat"):
return 2
default:
return 1
}
}
func selectOwnedAdminAuthIdentity(records []*dbent.AuthIdentity, userID int64) *dbent.AuthIdentity {
var selected *dbent.AuthIdentity
for _, record := range records {
if record.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityOwnershipConflict(records []*dbent.AuthIdentity, userID int64) bool {
for _, record := range records {
if record.UserID != userID {
return true
}
}
return false
}
func selectOwnedAdminAuthIdentityChannel(records []*dbent.AuthIdentityChannel, userID int64) *dbent.AuthIdentityChannel {
var selected *dbent.AuthIdentityChannel
for _, record := range records {
if record.Edges.Identity == nil || record.Edges.Identity.UserID != userID {
continue
}
if selected == nil || adminAuthIdentityProviderKeyRank(record.ProviderType, record.ProviderKey) < adminAuthIdentityProviderKeyRank(selected.ProviderType, selected.ProviderKey) {
selected = record
}
}
return selected
}
func hasAdminAuthIdentityChannelOwnershipConflict(records []*dbent.AuthIdentityChannel, userID int64) bool {
for _, record := range records {
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
return true
}
}
return false
}
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil {
return nil
......
......@@ -188,6 +188,93 @@ func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
require.Equal(t, "second", identities[0].Metadata["source"])
}
func TestAdminServiceBindUserAuthIdentityReusesLegacyWeChatAliasRecords(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("wechat-alias@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
legacyIdentity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetProviderSubject("union-legacy-123").
SetMetadata(map[string]any{"source": "legacy"}).
Save(ctx)
require.NoError(t, err)
legacyChannel, err := client.AuthIdentityChannel.Create().
SetIdentityID(legacyIdentity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("open").
SetChannelAppID("wx-open").
SetChannelSubject("openid-legacy-123").
SetMetadata(map[string]any{"scene": "legacy"}).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-legacy-123",
Metadata: map[string]any{"source": "admin-repair"},
Channel: &AdminBindAuthIdentityChannelInput{
Channel: "open",
ChannelAppID: "wx-open",
ChannelSubject: "openid-legacy-123",
Metadata: map[string]any{"scene": "admin-repair"},
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "wechat-main", result.ProviderKey)
require.NotNil(t, result.Channel)
require.Equal(t, "open", result.Channel.Channel)
identity, err := client.AuthIdentity.Get(ctx, legacyIdentity.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", identity.ProviderKey)
require.Equal(t, "admin-repair", identity.Metadata["source"])
channel, err := client.AuthIdentityChannel.Get(ctx, legacyChannel.ID)
require.NoError(t, err)
require.Equal(t, "wechat-main", channel.ProviderKey)
require.Equal(t, legacyIdentity.ID, channel.IdentityID)
require.Equal(t, "admin-repair", channel.Metadata["scene"])
identityCount, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderSubjectEQ("union-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, identityCount)
channelCount, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open"),
authidentitychannel.ChannelSubjectEQ("openid-legacy-123"),
).
Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, channelCount)
}
func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
......
......@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// BindEmailIdentity verifies and binds a local email/password identity to the
......@@ -69,6 +70,7 @@ func (s *AuthService) BindEmailIdentity(
if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
......@@ -87,6 +89,7 @@ func (s *AuthService) BindEmailIdentity(
}
}
s.revokeEmailIdentitySessions(ctx, userID)
return currentUser, nil
}
......@@ -219,6 +222,12 @@ func (s *AuthService) updateBoundEmailIdentityWithClient(
return nil
}
func (s *AuthService) revokeEmailIdentitySessions(ctx context.Context, userID int64) {
if err := s.RevokeAllUserSessions(ctx, userID); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh sessions after email identity bind for user %d: %v", userID, err)
}
}
func replaceBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,
......
......@@ -14,10 +14,14 @@ import (
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
switch signupSource {
case "", "email":
return "email"
case "linuxdo", "wechat", "oidc":
return signupSource
default:
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
......@@ -136,10 +140,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
signupSource = "email"
}
signupSource = normalizeOAuthSignupSource(signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{
......@@ -149,6 +150,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
SignupSource: signupSource,
}
if err := s.userRepo.Create(ctx, user); err != nil {
......
......@@ -191,6 +191,80 @@ func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFai
require.Empty(t, redeemRepo.updateCalls)
}
func TestRegisterOAuthEmailAccountSetsNormalizedSignupSourceOnCreatedUser(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"",
" OIDC ",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "oidc", userRepo.created[0].SignupSource)
}
func TestRegisterOAuthEmailAccountFallsBackUnknownSignupSourceToEmail(t *testing.T) {
userRepo := &userRepoStub{nextID: 43}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fallback@example.com",
"secret-123",
"246810",
"",
"github",
)
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
require.Len(t, userRepo.created, 1)
require.Equal(t, "email", userRepo.created[0].SignupSource)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{
......
......@@ -5,10 +5,15 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"hash/fnv"
"sort"
"strings"
"sync"
"time"
"entgo.io/ent/dialect"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
......@@ -75,6 +80,122 @@ type AuthPendingIdentityService struct {
entClient *dbent.Client
}
var authPendingIdentityScopedKeyLocks = newAuthPendingIdentityScopedKeyLockRegistry()
type authPendingIdentityScopedKeyLockRegistry struct {
mu sync.Mutex
locks map[string]*authPendingIdentityScopedKeyLockEntry
}
type authPendingIdentityScopedKeyLockEntry struct {
mu sync.Mutex
refs int
}
func newAuthPendingIdentityScopedKeyLockRegistry() *authPendingIdentityScopedKeyLockRegistry {
return &authPendingIdentityScopedKeyLockRegistry{
locks: make(map[string]*authPendingIdentityScopedKeyLockEntry),
}
}
func (r *authPendingIdentityScopedKeyLockRegistry) lock(keys ...string) func() {
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 {
return func() {}
}
entries := make([]*authPendingIdentityScopedKeyLockEntry, 0, len(normalized))
r.mu.Lock()
for _, key := range normalized {
entry := r.locks[key]
if entry == nil {
entry = &authPendingIdentityScopedKeyLockEntry{}
r.locks[key] = entry
}
entry.refs++
entries = append(entries, entry)
}
r.mu.Unlock()
for _, entry := range entries {
entry.mu.Lock()
}
return func() {
for i := len(entries) - 1; i >= 0; i-- {
entries[i].mu.Unlock()
}
r.mu.Lock()
defer r.mu.Unlock()
for idx, key := range normalized {
entry := entries[idx]
entry.refs--
if entry.refs == 0 {
delete(r.locks, key)
}
}
}
}
func normalizeAuthPendingIdentityLockKeys(keys ...string) []string {
if len(keys) == 0 {
return nil
}
deduped := make(map[string]struct{}, len(keys))
for _, key := range keys {
trimmed := strings.TrimSpace(key)
if trimmed == "" {
continue
}
deduped[trimmed] = struct{}{}
}
if len(deduped) == 0 {
return nil
}
normalized := make([]string, 0, len(deduped))
for key := range deduped {
normalized = append(normalized, key)
}
sort.Strings(normalized)
return normalized
}
func authPendingIdentityAdvisoryLockHash(key string) int64 {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(key))
return int64(hasher.Sum64())
}
func lockAuthPendingIdentityKeys(ctx context.Context, client *dbent.Client, keys ...string) (func(), error) {
release := authPendingIdentityScopedKeyLocks.lock(keys...)
normalized := normalizeAuthPendingIdentityLockKeys(keys...)
if len(normalized) == 0 || client == nil || client.Driver().Dialect() != dialect.Postgres {
return release, nil
}
for _, key := range normalized {
var rows entsql.Rows
if err := client.Driver().Query(ctx, "SELECT pg_advisory_xact_lock($1)", []any{authPendingIdentityAdvisoryLockHash(key)}, &rows); err != nil {
release()
return nil, err
}
_ = rows.Close()
}
return release, nil
}
func pendingIdentityAdoptionLockKeys(pendingAuthSessionID int64, identityID *int64) []string {
keys := []string{fmt.Sprintf("pending-auth-adoption:pending:%d", pendingAuthSessionID)}
if identityID != nil && *identityID > 0 {
keys = append(keys, fmt.Sprintf("pending-auth-adoption:identity:%d", *identityID))
}
return keys
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
......@@ -236,16 +357,66 @@ func (s *AuthPendingIdentityService) consumeSession(
return nil, err
}
sanitizedLocalFlowState := sanitizePendingAuthLocalFlowState(session.LocalFlowState)
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
Where(
pendingauthsession.ConsumedAtIsNil(),
pendingauthsession.ExpiresAtGTE(now),
pendingauthsession.Or(
pendingauthsession.CompletionCodeExpiresAtIsNil(),
pendingauthsession.CompletionCodeExpiresAtGTE(now),
),
).
SetConsumedAt(now).
SetLocalFlowState(sanitizedLocalFlowState).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
ClearCompletionCodeExpiresAt()
if expectedBrowserSessionKey := strings.TrimSpace(session.BrowserSessionKey); expectedBrowserSessionKey != "" {
update = update.Where(pendingauthsession.BrowserSessionKeyEQ(expectedBrowserSessionKey))
}
updated, err := update.Save(ctx)
if err == nil {
return updated, nil
}
if !dbent.IsNotFound(err) {
return nil, err
}
current, currentErr := s.entClient.PendingAuthSession.Get(ctx, session.ID)
if currentErr != nil {
if dbent.IsNotFound(currentErr) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, currentErr
}
if err := validatePendingSessionState(current, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
return updated, nil
return nil, consumedErr
}
func sanitizePendingAuthLocalFlowState(localFlowState map[string]any) map[string]any {
sanitized := copyPendingMap(localFlowState)
if len(sanitized) == 0 {
return sanitized
}
rawCompletion, ok := sanitized["completion_response"]
if !ok {
return sanitized
}
completion, ok := rawCompletion.(map[string]any)
if !ok {
return sanitized
}
cleanedCompletion := copyPendingMap(completion)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(cleanedCompletion, key)
}
sanitized["completion_response"] = cleanedCompletion
return sanitized
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
......@@ -274,8 +445,29 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
return nil, fmt.Errorf("pending auth ent client is not configured")
}
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
client := s.entClient
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
client = tx.Client()
txCtx = dbent.NewTxContext(ctx, tx)
} else if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
client = existingTx.Client()
}
releaseLocks, err := lockAuthPendingIdentityKeys(txCtx, client, pendingIdentityAdoptionLockKeys(input.PendingAuthSessionID, input.IdentityID)...)
if err != nil {
return nil, err
}
defer releaseLocks()
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := s.entClient.IdentityAdoptionDecision.Update().
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
......@@ -287,36 +479,40 @@ func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context,
}),
).
ClearIdentityID().
Save(ctx); err != nil {
Save(txCtx); err != nil {
return nil, err
}
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil && *input.IdentityID > 0 {
create = create.SetIdentityID(*input.IdentityID)
}
decisionID, err := create.
OnConflictColumns(identityadoptiondecision.FieldPendingAuthSessionID).
UpdateNewValues().
ID(txCtx)
if err != nil {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
decision, err := client.IdentityAdoptionDecision.Get(txCtx, decisionID)
if err != nil {
return nil, err
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
return update.Save(ctx)
return decision, nil
}
func copyPendingMap(in map[string]any) map[string]any {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment