Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
......@@ -3,14 +3,21 @@
package provider
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/wechatpay-apiv3/wechatpay-go/core"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/jsapi"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
)
// generateTestKeyPair returns a fresh RSA 2048 key pair as PEM strings.
......@@ -120,6 +127,33 @@ func TestWxSV(t *testing.T) {
}
}
func TestBuildWxpayTransactionMetadata(t *testing.T) {
t.Parallel()
tx := &payments.Transaction{
Appid: strPtr("wx-app-id"),
Mchid: strPtr("mch-id"),
TradeState: strPtr(wxpayTradeStateSuccess),
Amount: &payments.TransactionAmount{
Currency: strPtr(wxpayCurrency),
},
}
metadata := buildWxpayTransactionMetadata(tx)
if metadata[wxpayMetadataAppID] != "wx-app-id" {
t.Fatalf("appid = %q", metadata[wxpayMetadataAppID])
}
if metadata[wxpayMetadataMerchantID] != "mch-id" {
t.Fatalf("mchid = %q", metadata[wxpayMetadataMerchantID])
}
if metadata[wxpayMetadataCurrency] != wxpayCurrency {
t.Fatalf("currency = %q", metadata[wxpayMetadataCurrency])
}
if metadata[wxpayMetadataTradeState] != wxpayTradeStateSuccess {
t.Fatalf("trade_state = %q", metadata[wxpayMetadataTradeState])
}
}
func strPtr(s string) *string {
return &s
}
......@@ -300,3 +334,310 @@ func TestNewWxpay(t *testing.T) {
})
}
}
func TestBuildWxpayResultURLPreservesResumeToken(t *testing.T) {
t.Parallel()
resultURL, err := buildWxpayResultURL("https://app.example.com/payment/result?order_id=42&resume_token=resume-42&status=success", payment.CreatePaymentRequest{
OrderID: "sub2_42",
PaymentType: payment.TypeWxpay,
})
if err != nil {
t.Fatalf("buildWxpayResultURL returned error: %v", err)
}
parsed, err := url.Parse(resultURL)
if err != nil {
t.Fatalf("url.Parse returned error: %v", err)
}
query := parsed.Query()
if parsed.Path != wxpayResultPath {
t.Fatalf("path = %q, want %q", parsed.Path, wxpayResultPath)
}
if query.Get("resume_token") != "resume-42" {
t.Fatalf("resume_token = %q, want %q", query.Get("resume_token"), "resume-42")
}
if query.Get("order_id") != "42" {
t.Fatalf("order_id = %q, want %q", query.Get("order_id"), "42")
}
if query.Get("out_trade_no") != "sub2_42" {
t.Fatalf("out_trade_no = %q, want %q", query.Get("out_trade_no"), "sub2_42")
}
}
func TestResolveWxpayJSAPIAppID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
config map[string]string
want string
}{
{
name: "prefers dedicated mp app id",
config: map[string]string{
"mpAppId": "wx-mp-app",
"appId": "wx-merchant-app",
},
want: "wx-mp-app",
},
{
name: "falls back to merchant app id",
config: map[string]string{
"appId": "wx-merchant-app",
},
want: "wx-merchant-app",
},
{
name: "missing app ids returns empty",
config: map[string]string{},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := ResolveWxpayJSAPIAppID(tt.config); got != tt.want {
t.Fatalf("ResolveWxpayJSAPIAppID() = %q, want %q", got, tt.want)
}
})
}
}
func TestResolveWxpayCreateMode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
req payment.CreatePaymentRequest
wantMode string
wantErr string
}{
{
name: "desktop uses native",
req: payment.CreatePaymentRequest{},
wantMode: wxpayModeNative,
},
{
name: "mobile uses h5 when client ip is present",
req: payment.CreatePaymentRequest{
IsMobile: true,
ClientIP: "203.0.113.10",
},
wantMode: wxpayModeH5,
},
{
name: "mobile without client ip returns clear error",
req: payment.CreatePaymentRequest{
IsMobile: true,
},
wantErr: "requires client IP",
},
{
name: "openid uses jsapi mode",
req: payment.CreatePaymentRequest{
OpenID: "openid-123",
},
wantMode: wxpayModeJSAPI,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := resolveWxpayCreateMode(tt.req)
if tt.wantErr != "" {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("error %q should contain %q", err.Error(), tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.wantMode {
t.Fatalf("resolveWxpayCreateMode() = %q, want %q", got, tt.wantMode)
}
})
}
}
func TestCreatePaymentWithOpenIDReturnsJSAPIResult(t *testing.T) {
origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
origNativePrepay := wxpayNativePrepay
origH5Prepay := wxpayH5Prepay
t.Cleanup(func() {
wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
wxpayNativePrepay = origNativePrepay
wxpayH5Prepay = origH5Prepay
})
jsapiCalls := 0
nativeCalls := 0
h5Calls := 0
wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
jsapiCalls++
if got := wxSV(req.Payer.Openid); got != "openid-123" {
t.Fatalf("openid = %q, want %q", got, "openid-123")
}
if req.SceneInfo == nil || wxSV(req.SceneInfo.PayerClientIp) != "203.0.113.10" {
t.Fatalf("scene_info payer_client_ip = %q, want %q", wxSV(req.SceneInfo.PayerClientIp), "203.0.113.10")
}
return &jsapi.PrepayWithRequestPaymentResponse{
Appid: core.String("wx123"),
TimeStamp: core.String("1712345678"),
NonceStr: core.String("nonce-123"),
Package: core.String("prepay_id=wx_prepay_123"),
SignType: core.String("RSA"),
PaySign: core.String("signed-payload"),
}, nil, nil
}
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
nativeCalls++
return &native.PrepayResponse{}, nil, nil
}
wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
h5Calls++
return &h5.PrepayResponse{}, nil, nil
}
provider := &Wxpay{
config: map[string]string{
"appId": "wx123",
"mchId": "mch123",
},
coreClient: &core.Client{},
}
resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_88",
Amount: "66.88",
PaymentType: payment.TypeWxpay,
NotifyURL: "https://merchant.example/payment/notify",
OpenID: "openid-123",
ClientIP: "203.0.113.10",
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if jsapiCalls != 1 {
t.Fatalf("jsapi prepay calls = %d, want 1", jsapiCalls)
}
if nativeCalls != 0 {
t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
}
if h5Calls != 0 {
t.Fatalf("h5 prepay calls = %d, want 0", h5Calls)
}
if resp.ResultType != payment.CreatePaymentResultJSAPIReady {
t.Fatalf("result type = %q, want %q", resp.ResultType, payment.CreatePaymentResultJSAPIReady)
}
if resp.JSAPI == nil {
t.Fatal("expected jsapi payload, got nil")
}
if resp.JSAPI.AppID != "wx123" {
t.Fatalf("jsapi appId = %q, want %q", resp.JSAPI.AppID, "wx123")
}
if resp.JSAPI.TimeStamp != "1712345678" {
t.Fatalf("jsapi timeStamp = %q, want %q", resp.JSAPI.TimeStamp, "1712345678")
}
if resp.JSAPI.NonceStr != "nonce-123" {
t.Fatalf("jsapi nonceStr = %q, want %q", resp.JSAPI.NonceStr, "nonce-123")
}
if resp.JSAPI.Package != "prepay_id=wx_prepay_123" {
t.Fatalf("jsapi package = %q, want %q", resp.JSAPI.Package, "prepay_id=wx_prepay_123")
}
if resp.JSAPI.SignType != "RSA" {
t.Fatalf("jsapi signType = %q, want %q", resp.JSAPI.SignType, "RSA")
}
if resp.JSAPI.PaySign != "signed-payload" {
t.Fatalf("jsapi paySign = %q, want %q", resp.JSAPI.PaySign, "signed-payload")
}
}
func TestCreatePaymentMobileH5IncludesConfiguredSceneInfo(t *testing.T) {
origJSAPIPrepay := wxpayJSAPIPrepayWithRequestPayment
origNativePrepay := wxpayNativePrepay
origH5Prepay := wxpayH5Prepay
t.Cleanup(func() {
wxpayJSAPIPrepayWithRequestPayment = origJSAPIPrepay
wxpayNativePrepay = origNativePrepay
wxpayH5Prepay = origH5Prepay
})
jsapiCalls := 0
nativeCalls := 0
h5Calls := 0
wxpayJSAPIPrepayWithRequestPayment = func(ctx context.Context, svc jsapi.JsapiApiService, req jsapi.PrepayRequest) (*jsapi.PrepayWithRequestPaymentResponse, *core.APIResult, error) {
jsapiCalls++
return &jsapi.PrepayWithRequestPaymentResponse{}, nil, nil
}
wxpayNativePrepay = func(ctx context.Context, svc native.NativeApiService, req native.PrepayRequest) (*native.PrepayResponse, *core.APIResult, error) {
nativeCalls++
return &native.PrepayResponse{}, nil, nil
}
wxpayH5Prepay = func(ctx context.Context, svc h5.H5ApiService, req h5.PrepayRequest) (*h5.PrepayResponse, *core.APIResult, error) {
h5Calls++
if req.SceneInfo == nil {
t.Fatal("expected scene_info, got nil")
}
if got := wxSV(req.SceneInfo.PayerClientIp); got != "203.0.113.10" {
t.Fatalf("scene_info payer_client_ip = %q, want %q", got, "203.0.113.10")
}
if req.SceneInfo.H5Info == nil {
t.Fatal("expected scene_info.h5_info, got nil")
}
if got := wxSV(req.SceneInfo.H5Info.Type); got != wxpayH5Type {
t.Fatalf("scene_info.h5_info.type = %q, want %q", got, wxpayH5Type)
}
if got := wxSV(req.SceneInfo.H5Info.AppName); got != "Sub2API" {
t.Fatalf("scene_info.h5_info.app_name = %q, want %q", got, "Sub2API")
}
if got := wxSV(req.SceneInfo.H5Info.AppUrl); got != "https://app.example.com" {
t.Fatalf("scene_info.h5_info.app_url = %q, want %q", got, "https://app.example.com")
}
return &h5.PrepayResponse{
H5Url: core.String("https://wx.tenpay.example/h5pay?prepay_id=1"),
}, nil, nil
}
provider := &Wxpay{
config: map[string]string{
"appId": "wx123",
"mchId": "mch123",
"h5AppName": "Sub2API",
"h5AppUrl": "https://app.example.com",
},
coreClient: &core.Client{},
}
resp, err := provider.CreatePayment(context.Background(), payment.CreatePaymentRequest{
OrderID: "sub2_99",
Amount: "66.88",
PaymentType: payment.TypeWxpay,
Subject: "Balance Recharge",
NotifyURL: "https://merchant.example/payment/notify",
ReturnURL: "https://merchant.example/payment/result?resume_token=resume-99",
ClientIP: "203.0.113.10",
IsMobile: true,
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if jsapiCalls != 0 {
t.Fatalf("jsapi prepay calls = %d, want 0", jsapiCalls)
}
if nativeCalls != 0 {
t.Fatalf("native prepay calls = %d, want 0", nativeCalls)
}
if h5Calls != 1 {
t.Fatalf("h5 prepay calls = %d, want 1", h5Calls)
}
if !strings.Contains(resp.PayURL, "redirect_url=") {
t.Fatalf("pay_url = %q, want redirect_url query appended", resp.PayURL)
}
}
......@@ -101,34 +101,69 @@ type CreatePaymentRequest struct {
Subject string // Product description
NotifyURL string // Webhook callback URL
ReturnURL string // Browser redirect URL after payment
OpenID string // WeChat JSAPI payer OpenID when available
ClientIP string // Payer's IP address
IsMobile bool // Whether the request comes from a mobile device
InstanceSubMethods string // Comma-separated sub-methods from instance supported_types (for Stripe)
}
// CreatePaymentResultType describes the shape of the create-payment result.
type CreatePaymentResultType = string
const (
CreatePaymentResultOrderCreated CreatePaymentResultType = "order_created"
CreatePaymentResultOAuthRequired CreatePaymentResultType = "oauth_required"
CreatePaymentResultJSAPIReady CreatePaymentResultType = "jsapi_ready"
)
// WechatOAuthInfo describes the next step when WeChat OAuth is required before payment.
type WechatOAuthInfo struct {
AuthorizeURL string `json:"authorize_url,omitempty"`
AppID string `json:"appid,omitempty"`
OpenID string `json:"openid,omitempty"`
Scope string `json:"scope,omitempty"`
State string `json:"state,omitempty"`
RedirectURL string `json:"redirect_url,omitempty"`
}
// WechatJSAPIPayload contains the fields the frontend needs to invoke WeChat JSAPI payment.
type WechatJSAPIPayload struct {
AppID string `json:"appId,omitempty"`
TimeStamp string `json:"timeStamp,omitempty"`
NonceStr string `json:"nonceStr,omitempty"`
Package string `json:"package,omitempty"`
SignType string `json:"signType,omitempty"`
PaySign string `json:"paySign,omitempty"`
}
// CreatePaymentResponse is returned after successfully initiating a payment.
type CreatePaymentResponse struct {
TradeNo string // Third-party transaction ID
PayURL string // H5 payment URL (alipay/wxpay)
QRCode string // QR code content for scanning
ClientSecret string // Stripe PaymentIntent client secret
TradeNo string // Third-party transaction ID
PayURL string // H5 payment URL (alipay/wxpay)
QRCode string // QR code content for scanning
ClientSecret string // Stripe PaymentIntent client secret
ResultType CreatePaymentResultType // Typed result contract for frontend flows
OAuth *WechatOAuthInfo // WeChat OAuth bootstrap payload when required
JSAPI *WechatJSAPIPayload // WeChat JSAPI invocation payload when ready
}
// QueryOrderResponse describes the payment status from the upstream provider.
type QueryOrderResponse struct {
TradeNo string
Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY
PaidAt string // RFC3339 timestamp or empty
TradeNo string
Status string // "pending", "paid", "failed", "refunded"
Amount float64 // Amount in CNY
PaidAt string // RFC3339 timestamp or empty
Metadata map[string]string
}
// PaymentNotification is the parsed result of a webhook/notify callback.
type PaymentNotification struct {
TradeNo string
OrderID string
Amount float64
Status string // "success" or "failed"
RawData string // Raw notification body for audit
TradeNo string
OrderID string
Amount float64
Status string // "success" or "failed"
RawData string // Raw notification body for audit
Metadata map[string]string
}
// RefundRequest contains the parameters for requesting a refund.
......@@ -179,3 +214,9 @@ type CancelableProvider interface {
// CancelPayment cancels/expires a pending payment on the upstream platform.
CancelPayment(ctx context.Context, tradeNo string) error
}
// MerchantIdentityProvider exposes the current non-sensitive merchant identity
// derived from provider configuration for snapshot consistency checks.
type MerchantIdentityProvider interface {
MerchantIdentityMetadata() map[string]string
}
......@@ -19,13 +19,17 @@ func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementRea
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
client := clientFromContext(ctx, r.client)
return client.AnnouncementRead.Create().
err := client.AnnouncementRead.Create().
SetAnnouncementID(announcementID).
SetUserID(userID).
SetReadAt(readAt).
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
DoNothing().
Exec(ctx)
if isSQLNoRowsError(err) {
return nil
}
return err
}
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
......
......@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
user.FieldSignupSource,
user.FieldLastLoginAt,
user.FieldLastActiveAt,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
......@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SignupSource: u.SignupSource,
LastLoginAt: u.LastLoginAt,
LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
......
//go:build integration
package repository
import (
"context"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityCompatBackfillMigration_AllowsLongReportTypes(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration108Path := filepath.Join("..", "..", "migrations", "108_auth_identity_foundation_core.sql")
migration108SQL, err := os.ReadFile(migration108Path)
require.NoError(t, err)
migration109Path := filepath.Join("..", "..", "migrations", "109_auth_identity_compat_backfill.sql")
migration109SQL, err := os.ReadFile(migration109Path)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
DROP TABLE IF EXISTS auth_identity_migration_reports CASCADE;
DROP TABLE IF EXISTS auth_identity_channels CASCADE;
DROP TABLE IF EXISTS identity_adoption_decisions CASCADE;
DROP TABLE IF EXISTS pending_auth_sessions CASCADE;
DROP TABLE IF EXISTS auth_identities CASCADE;
ALTER TABLE users
DROP COLUMN IF EXISTS signup_source,
DROP COLUMN IF EXISTS last_login_at,
DROP COLUMN IF EXISTS last_active_at;
`)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration108SQL))
require.NoError(t, err)
var userID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('oidc-demo-subject@oidc-connect.invalid', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&userID))
_, err = tx.ExecContext(ctx, string(migration109SQL))
require.NoError(t, err)
var reportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'oidc_synthetic_email_requires_manual_recovery'
AND report_key = $1
`, strconv.FormatInt(userID, 10)).Scan(&reportCount))
require.Equal(t, 1, reportCount)
var reportTypeLimit int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = 'auth_identity_migration_reports'
AND column_name = 'report_type'
`).Scan(&reportTypeLimit))
require.GreaterOrEqual(t, reportTypeLimit, 45)
require.NotZero(t, userID)
}
//go:build integration
package repository
import (
"context"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestAuthIdentityLegacyExternalBackfillMigration(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
var linuxDoUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoUserID))
var wechatUnionUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-union@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatUnionUserID))
var wechatOpenIDOnlyUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-openid@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatOpenIDOnlyUserID))
var syntheticAuthIdentityID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
VALUES ($1, 'wechat', 'wechat-main', 'openid-synthetic', '{"backfill_source":"synthetic_email"}'::jsonb)
RETURNING id`, wechatOpenIDOnlyUserID).Scan(&syntheticAuthIdentityID))
var linuxDoLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-user-1', NULL, 'linux-user', 'Linux User', '{"source":"legacy"}')
RETURNING id
`, linuxDoUserID).Scan(&linuxDoLegacyID))
var wechatUnionLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-union-1', 'union-1', 'wechat-union-user', 'WeChat Union User', '{"channel":"oa","appid":"wx-app-1"}')
RETURNING id
`, wechatUnionUserID).Scan(&wechatUnionLegacyID))
var wechatOpenIDLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-only-1', NULL, 'wechat-openid-user', 'WeChat OpenID User', '{"channel":"oa","appid":"wx-app-2"}')
RETURNING id
`, wechatOpenIDOnlyUserID).Scan(&wechatOpenIDLegacyID))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxDoCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-user-1'
`, linuxDoUserID).Scan(&linuxDoCount))
require.Equal(t, 1, linuxDoCount)
var wechatSubject string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT provider_subject
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-1'
`, wechatUnionUserID).Scan(&wechatSubject))
require.Equal(t, "union-1", wechatSubject)
var wechatChannelCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_channels channel
JOIN auth_identities ai ON ai.id = channel.identity_id
WHERE ai.user_id = $1
AND channel.provider_type = 'wechat'
AND channel.provider_key = 'wechat-main'
AND channel.channel = 'oa'
AND channel.channel_app_id = 'wx-app-1'
AND channel.channel_subject = 'openid-union-1'
`, wechatUnionUserID).Scan(&wechatChannelCount))
require.Equal(t, 1, wechatChannelCount)
var legacyOpenIDOnlyReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDLegacyID, 10)).Scan(&legacyOpenIDOnlyReportCount))
require.Equal(t, 1, legacyOpenIDOnlyReportCount)
var syntheticReviewCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "synthetic_auth_identity:"+strconv.FormatInt(syntheticAuthIdentityID, 10)).Scan(&syntheticReviewCount))
require.Equal(t, 1, syntheticReviewCount)
var unionLegacyReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatUnionLegacyID, 10)).Scan(&unionLegacyReportCount))
require.Zero(t, unionLegacyReportCount)
require.NotZero(t, linuxDoLegacyID)
}
func TestAuthIdentityLegacyExternalBackfillMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
var beforeCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&beforeCount))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var afterCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
}
func TestAuthIdentityLegacyExternalMigrations_ChainHandlesMalformedAndNonObjectMetadata(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migration115Path := filepath.Join("..", "..", "migrations", "115_auth_identity_legacy_external_backfill.sql")
migration115SQL, err := os.ReadFile(migration115Path)
require.NoError(t, err)
migration116Path := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migration116SQL, err := os.ReadFile(migration116Path)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
var linuxDoMalformedUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-malformed@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoMalformedUserID))
var linuxDoArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-linuxdo-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&linuxDoArrayUserID))
var wechatUnionArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatUnionArrayUserID))
var wechatOpenIDArrayUserID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ('legacy-wechat-openid-array@example.com', 'hash', 'user', 'active', 0, 1)
RETURNING id`).Scan(&wechatOpenIDArrayUserID))
var linuxDoMalformedLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-malformed', NULL, 'legacy-linuxdo-malformed', 'Legacy LinuxDo Malformed', '{invalid')
RETURNING id
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedLegacyID))
var linuxDoArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-array', NULL, 'legacy-linuxdo-array', 'Legacy LinuxDo Array', '["legacy-linuxdo-array"]')
RETURNING id
`, linuxDoArrayUserID).Scan(&linuxDoArrayLegacyID))
var wechatUnionArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-array', 'union-array', 'legacy-wechat-array', 'Legacy WeChat Array', '["legacy-wechat-array"]')
RETURNING id
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayLegacyID))
var wechatOpenIDArrayLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-array-only', NULL, 'legacy-wechat-array-only', 'Legacy WeChat Array Only', '["legacy-wechat-openid-array"]')
RETURNING id
`, wechatOpenIDArrayUserID).Scan(&wechatOpenIDArrayLegacyID))
_, err = tx.ExecContext(ctx, string(migration115SQL))
require.NoError(t, err)
_, err = tx.ExecContext(ctx, string(migration116SQL))
require.NoError(t, err)
var linuxDoMalformedMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-malformed'
`, linuxDoMalformedUserID).Scan(&linuxDoMalformedMetadataType))
require.Equal(t, "object", linuxDoMalformedMetadataType)
var linuxDoArrayMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-array'
`, linuxDoArrayUserID).Scan(&linuxDoArrayMetadataType))
require.Equal(t, "object", linuxDoArrayMetadataType)
var wechatUnionArrayMetadataType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(metadata)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'wechat'
AND provider_key = 'wechat-main'
AND provider_subject = 'union-array'
`, wechatUnionArrayUserID).Scan(&wechatUnionArrayMetadataType))
require.Equal(t, "object", wechatUnionArrayMetadataType)
var invalidJSONReportDetailsType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(details)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(linuxDoMalformedLegacyID, 10)).Scan(&invalidJSONReportDetailsType))
require.Equal(t, "object", invalidJSONReportDetailsType)
var openIDOnlyReportDetailsType string
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT jsonb_typeof(details)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatOpenIDArrayLegacyID, 10)).Scan(&openIDOnlyReportDetailsType))
require.Equal(t, "object", openIDOnlyReportDetailsType)
var preservedArrayMetadataCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE id IN (
SELECT id
FROM auth_identities
WHERE (user_id = $1 AND provider_subject = 'linuxdo-array')
OR (user_id = $2 AND provider_subject = 'union-array')
)
AND metadata ? '_legacy_metadata_raw_json'
`, linuxDoArrayUserID, wechatUnionArrayUserID).Scan(&preservedArrayMetadataCount))
require.Equal(t, 2, preservedArrayMetadataCount)
require.NotZero(t, linuxDoArrayLegacyID)
require.NotZero(t, wechatUnionArrayLegacyID)
}
func TestAuthIdentityLegacyExternalSafetyMigration_ReportsConflictsAndDowngradesInvalidJSON(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
_, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS user_external_identities (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
provider TEXT NOT NULL,
provider_user_id TEXT NOT NULL,
provider_union_id TEXT NULL,
provider_username TEXT NOT NULL DEFAULT '',
display_name TEXT NOT NULL DEFAULT '',
profile_url TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
metadata TEXT NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
);
TRUNCATE TABLE
auth_identity_channels,
identity_adoption_decisions,
auth_identities,
auth_identity_migration_reports,
user_external_identities,
users
RESTART IDENTITY CASCADE;
`)
require.NoError(t, err)
userIDs := make([]int64, 0, 8)
for _, email := range []string{
"linuxdo-conflict-legacy@example.com",
"linuxdo-conflict-owner@example.com",
"wechat-conflict-legacy@example.com",
"wechat-conflict-owner@example.com",
"wechat-channel-legacy@example.com",
"wechat-channel-owner@example.com",
"linuxdo-invalid-json@example.com",
"wechat-openid-invalid-json@example.com",
} {
var userID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO users (email, password_hash, role, status, balance, concurrency)
VALUES ($1, 'hash', 'user', 'active', 0, 1)
RETURNING id`, email).Scan(&userID))
userIDs = append(userIDs, userID)
}
linuxdoConflictLegacyUserID := userIDs[0]
linuxdoConflictOwnerUserID := userIDs[1]
wechatConflictLegacyUserID := userIDs[2]
wechatConflictOwnerUserID := userIDs[3]
wechatChannelLegacyUserID := userIDs[4]
wechatChannelOwnerUserID := userIDs[5]
linuxdoInvalidJSONUserID := userIDs[6]
wechatInvalidOpenIDUserID := userIDs[7]
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
VALUES ($1, 'linuxdo', 'linuxdo', 'linuxdo-conflict', '{}'::jsonb)
RETURNING id`, linuxdoConflictOwnerUserID).Scan(new(int64)))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
VALUES ($1, 'wechat', 'wechat-main', 'union-conflict', '{}'::jsonb)
RETURNING id`, wechatConflictOwnerUserID).Scan(new(int64)))
var wechatChannelOwnerIdentityID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identities (user_id, provider_type, provider_key, provider_subject, metadata)
VALUES ($1, 'wechat', 'wechat-main', 'union-channel-owner', '{}'::jsonb)
RETURNING id`, wechatChannelOwnerUserID).Scan(&wechatChannelOwnerIdentityID))
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO auth_identity_channels (
identity_id,
provider_type,
provider_key,
channel,
channel_app_id,
channel_subject,
metadata
)
VALUES ($1, 'wechat', 'wechat-main', 'oa', 'wx-app-conflict', 'openid-channel-conflict', '{}'::jsonb)
RETURNING id`, wechatChannelOwnerIdentityID).Scan(new(int64)))
var linuxdoConflictLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-conflict', NULL, 'legacy-linuxdo', 'Legacy LinuxDo Conflict', '{"source":"legacy"}')
RETURNING id
`, linuxdoConflictLegacyUserID).Scan(&linuxdoConflictLegacyID))
var wechatConflictLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-union-conflict', 'union-conflict', 'legacy-wechat', 'Legacy WeChat Conflict', '{"channel":"oa","appid":"wx-app-conflict-canon"}')
RETURNING id
`, wechatConflictLegacyUserID).Scan(&wechatConflictLegacyID))
var wechatChannelConflictLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-channel-conflict', 'union-channel-legacy', 'legacy-wechat-channel', 'Legacy WeChat Channel Conflict', '{"channel":"oa","appid":"wx-app-conflict"}')
RETURNING id
`, wechatChannelLegacyUserID).Scan(&wechatChannelConflictLegacyID))
var linuxdoInvalidJSONLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'linuxdo', 'linuxdo-invalid-json', NULL, 'legacy-linuxdo-invalid', 'Legacy LinuxDo Invalid JSON', '{invalid')
RETURNING id
`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidJSONLegacyID))
var wechatInvalidOpenIDLegacyID int64
require.NoError(t, tx.QueryRowContext(ctx, `
INSERT INTO user_external_identities (
user_id,
provider,
provider_user_id,
provider_union_id,
provider_username,
display_name,
metadata
) VALUES ($1, 'wechat', 'openid-invalid-json-only', NULL, 'legacy-wechat-invalid', 'Legacy WeChat Invalid JSON', '{still-invalid')
RETURNING id
`, wechatInvalidOpenIDUserID).Scan(&wechatInvalidOpenIDLegacyID))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var linuxdoConflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(linuxdoConflictLegacyID, 10)).Scan(&linuxdoConflictReportCount))
require.Equal(t, 1, linuxdoConflictReportCount)
var wechatConflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_conflict'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatConflictLegacyID, 10)).Scan(&wechatConflictReportCount))
require.Equal(t, 1, wechatConflictReportCount)
var channelConflictReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_channel_conflict'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatChannelConflictLegacyID, 10)).Scan(&channelConflictReportCount))
require.Equal(t, 1, channelConflictReportCount)
var invalidJSONReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'legacy_external_identity_invalid_metadata_json'
AND report_key IN ($1, $2)
`, "legacy_external_identity:"+strconv.FormatInt(linuxdoInvalidJSONLegacyID, 10), "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&invalidJSONReportCount))
require.Equal(t, 2, invalidJSONReportCount)
var linuxdoInvalidIdentityCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identities
WHERE user_id = $1
AND provider_type = 'linuxdo'
AND provider_key = 'linuxdo'
AND provider_subject = 'linuxdo-invalid-json'
`, linuxdoInvalidJSONUserID).Scan(&linuxdoInvalidIdentityCount))
require.Equal(t, 1, linuxdoInvalidIdentityCount)
var wechatOpenIDOnlyReportCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
WHERE report_type = 'wechat_openid_only_requires_remediation'
AND report_key = $1
`, "legacy_external_identity:"+strconv.FormatInt(wechatInvalidOpenIDLegacyID, 10)).Scan(&wechatOpenIDOnlyReportCount))
require.Equal(t, 1, wechatOpenIDOnlyReportCount)
}
func TestAuthIdentityLegacyExternalSafetyMigration_IsSafeWhenLegacyTableMissing(t *testing.T) {
tx := testTx(t)
ctx := context.Background()
migrationPath := filepath.Join("..", "..", "migrations", "116_auth_identity_legacy_external_safety_reports.sql")
migrationSQL, err := os.ReadFile(migrationPath)
require.NoError(t, err)
var beforeCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&beforeCount))
_, err = tx.ExecContext(ctx, string(migrationSQL))
require.NoError(t, err)
var afterCount int
require.NoError(t, tx.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM auth_identity_migration_reports
`).Scan(&afterCount))
require.Equal(t, beforeCount, afterCount)
}
......@@ -73,6 +73,12 @@ var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibil
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
},
},
"109_auth_identity_compat_backfill.sql": {
fileChecksum: "551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
acceptedDBChecksum: map[string]struct{}{
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3": {},
},
},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
......
......@@ -51,4 +51,13 @@ func TestIsMigrationChecksumCompatible(t *testing.T) {
)
require.False(t, ok)
})
t.Run("109历史checksum可兼容", func(t *testing.T) {
ok := isMigrationChecksumCompatible(
"109_auth_identity_compat_backfill.sql",
"2b380305e73ff0c13aa8c811e45897f2b36ca4a438f7b3e8f98e19ecb6bae0b3",
"551e498aa5616d2d91096e9d72cf9fb36e418ee22eacc557f8811cadbc9e20ee",
)
require.True(t, ok)
})
}
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"unsafe"
entsql "entgo.io/ent/dialect/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
var (
ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT",
"auth identity already belongs to another user",
)
ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
ErrAuthIdentityChannelProviderMismatch = infraerrors.BadRequest(
"AUTH_IDENTITY_CHANNEL_PROVIDER_MISMATCH",
"auth identity channel provider must match canonical identity",
)
)
type ProviderGrantReason string
const (
ProviderGrantReasonSignup ProviderGrantReason = "signup"
ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
)
type AuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type AuthIdentityChannelKey struct {
ProviderType string
ProviderKey string
Channel string
ChannelAppID string
ChannelSubject string
}
type CreateAuthIdentityInput struct {
UserID int64
Canonical AuthIdentityKey
Channel *AuthIdentityChannelKey
Issuer *string
VerifiedAt *time.Time
Metadata map[string]any
ChannelMetadata map[string]any
}
type BindAuthIdentityInput = CreateAuthIdentityInput
type CreateAuthIdentityResult struct {
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
if r == nil || r.Identity == nil {
return AuthIdentityKey{}
}
return AuthIdentityKey{
ProviderType: r.Identity.ProviderType,
ProviderKey: r.Identity.ProviderKey,
ProviderSubject: r.Identity.ProviderSubject,
}
}
func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
if r == nil || r.Channel == nil {
return nil
}
return &AuthIdentityChannelKey{
ProviderType: r.Channel.ProviderType,
ProviderKey: r.Channel.ProviderKey,
Channel: r.Channel.Channel,
ChannelAppID: r.Channel.ChannelAppID,
ChannelSubject: r.Channel.ChannelSubject,
}
}
type UserAuthIdentityLookup struct {
User *dbent.User
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
type ProviderGrantRecordInput struct {
UserID int64
ProviderType string
GrantReason ProviderGrantReason
}
type IdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type sqlQueryExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil {
return fn(ctx)
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx); err != nil {
return err
}
return tx.Commit()
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt)
identity, err := create.Save(ctx)
if err != nil {
return nil, err
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(ctx)
if err != nil {
return nil, err
}
}
return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
}
func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
).
WithUser().
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: identity.Edges.User,
Identity: identity,
}, nil
}
func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
).
WithIdentity(func(q *dbent.AuthIdentityQuery) {
q.WithUser()
}).
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: channel.Edges.Identity.Edges.User,
Identity: channel.Edges.Identity,
Channel: channel,
}, nil
}
func (r *userRepository) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
identities, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
Where(authidentity.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
records := make([]service.UserAuthIdentityRecord, 0, len(identities))
for _, identity := range identities {
if identity == nil {
continue
}
records = append(records, service.UserAuthIdentityRecord{
ProviderType: strings.TrimSpace(identity.ProviderType),
ProviderKey: strings.TrimSpace(identity.ProviderKey),
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
VerifiedAt: identity.VerifiedAt,
Issuer: identity.Issuer,
Metadata: copyMetadata(identity.Metadata),
CreatedAt: identity.CreatedAt,
UpdatedAt: identity.UpdatedAt,
})
}
return records, nil
}
func (r *userRepository) UnbindUserAuthProvider(ctx context.Context, userID int64, provider string) error {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" || provider == "email" {
return service.ErrIdentityProviderInvalid
}
return r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
identityIDs, err := client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ(provider),
).
IDs(txCtx)
if err != nil {
return err
}
if len(identityIDs) == 0 {
return nil
}
if _, err := client.IdentityAdoptionDecision.Update().
Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
ClearIdentityID().
Save(txCtx); err != nil {
return err
}
if _, err := client.AuthIdentityChannel.Delete().
Where(authidentitychannel.IdentityIDIn(identityIDs...)).
Exec(txCtx); err != nil {
return err
}
_, err = client.AuthIdentity.Delete().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ(provider),
).
Exec(txCtx)
return err
})
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
if err := validateAuthIdentityChannelProviderMatch(input.Canonical, input.Channel); err != nil {
return nil, err
}
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if identity != nil && identity.UserID != input.UserID {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
identity, err = client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentity.UpdateOneID(identity.ID)
if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata))
}
if input.Issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
}
if input.VerifiedAt != nil {
update = update.SetVerifiedAt(*input.VerifiedAt)
}
identity, err = update.Save(txCtx)
if err != nil {
return err
}
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID)
if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
}
channel, err = update.Save(txCtx)
if err != nil {
return err
}
}
}
result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return false, fmt.Errorf("sql executor is not configured")
}
result, err := exec.ExecContext(ctx, `
INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
input.UserID,
strings.TrimSpace(input.ProviderType),
string(input.GrantReason),
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client)
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := client.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(ctx); err != nil {
return nil, err
}
}
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now)
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
Only(ctx)
}
func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastLoginAt(loginAt).
Save(ctx)
return err
}
func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastActiveAt(activeAt).
Save(ctx)
return err
}
func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
rows, err := exec.QueryContext(ctx, `
SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
FROM user_avatars
WHERE user_id = $1`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, rows.Err()
}
var avatar service.UserAvatar
if err := rows.Scan(
&avatar.StorageProvider,
&avatar.StorageKey,
&avatar.URL,
&avatar.ContentType,
&avatar.ByteSize,
&avatar.SHA256,
); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return &avatar, nil
}
func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
_, err = exec.ExecContext(ctx, `
INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (user_id) DO UPDATE SET
storage_provider = EXCLUDED.storage_provider,
storage_key = EXCLUDED.storage_key,
url = EXCLUDED.url,
content_type = EXCLUDED.content_type,
byte_size = EXCLUDED.byte_size,
sha256 = EXCLUDED.sha256,
updated_at = NOW()`,
userID,
strings.TrimSpace(input.StorageProvider),
strings.TrimSpace(input.StorageKey),
strings.TrimSpace(input.URL),
strings.TrimSpace(input.ContentType),
input.ByteSize,
strings.TrimSpace(input.SHA256),
)
if err != nil {
return nil, err
}
return &service.UserAvatar{
StorageProvider: strings.TrimSpace(input.StorageProvider),
StorageKey: strings.TrimSpace(input.StorageKey),
URL: strings.TrimSpace(input.URL),
ContentType: strings.TrimSpace(input.ContentType),
ByteSize: input.ByteSize,
SHA256: strings.TrimSpace(input.SHA256),
}, nil
}
func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return err
}
_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
return err
}
func copyMetadata(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func validateAuthIdentityChannelProviderMatch(canonical AuthIdentityKey, channel *AuthIdentityChannelKey) error {
if channel == nil {
return nil
}
canonicalProviderType := strings.TrimSpace(canonical.ProviderType)
canonicalProviderKey := strings.TrimSpace(canonical.ProviderKey)
channelProviderType := strings.TrimSpace(channel.ProviderType)
channelProviderKey := strings.TrimSpace(channel.ProviderKey)
if canonicalProviderType != channelProviderType || canonicalProviderKey != channelProviderKey {
return ErrAuthIdentityChannelProviderMismatch
}
return nil
}
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
return exec
}
}
if fallback != nil {
return fallback
}
return sqlExecutorFromEntClient(client)
}
func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
return exec, nil
}
func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
if client == nil {
return nil
}
clientValue := reflect.ValueOf(client).Elem()
configValue := clientValue.FieldByName("config")
driverValue := configValue.FieldByName("driver")
if !driverValue.IsValid() {
return nil
}
driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
exec, ok := driver.(sqlQueryExecutor)
if !ok {
return nil
}
return exec
}
//go:build integration
package repository
import (
"context"
"errors"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type UserProfileIdentityRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userRepository
}
func TestUserProfileIdentityRepoSuite(t *testing.T) {
suite.Run(t, new(UserProfileIdentityRepoSuite))
}
func (s *UserProfileIdentityRepoSuite) SetupTest() {
s.ctx = context.Background()
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
_, err := integrationDB.ExecContext(s.ctx, `
TRUNCATE TABLE
identity_adoption_decisions,
auth_identity_channels,
auth_identities,
pending_auth_sessions,
user_provider_default_grants,
user_avatars
RESTART IDENTITY`)
s.Require().NoError(err)
}
func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
s.T().Helper()
user, err := s.client.User.Create().
SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
SetPasswordHash("test-password-hash").
SetRole("user").
SetStatus("active").
Save(s.ctx)
s.Require().NoError(err)
return user
}
func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
s.T().Helper()
session, err := s.client.PendingAuthSession.Create().
SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
SetIntent("bind_current_user").
SetProviderType(key.ProviderType).
SetProviderKey(key.ProviderKey).
SetProviderSubject(key.ProviderSubject).
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(s.ctx)
s.Require().NoError(err)
return session
}
func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
user := s.mustCreateUser("canonical-channel")
verifiedAt := time.Now().UTC().Truncate(time.Second)
created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
Channel: "mp",
ChannelAppID: "wx-app",
ChannelSubject: "openid-123",
},
Issuer: stringPtr("https://issuer.example"),
VerifiedAt: &verifiedAt,
Metadata: map[string]any{"unionid": "union-123"},
ChannelMetadata: map[string]any{"openid": "openid-123"},
})
s.Require().NoError(err)
s.Require().NotNil(created.Identity)
s.Require().NotNil(created.Channel)
canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, canonical.User.ID)
s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, channel.User.ID)
s.Require().Equal(created.Identity.ID, channel.Identity.ID)
s.Require().Equal(created.Channel.ID, channel.Channel.ID)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
owner := s.mustCreateUser("owner")
other := s.mustCreateUser("other")
first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "first"},
ChannelMetadata: map[string]any{"scope": "read"},
})
s.Require().NoError(err)
second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "second"},
ChannelMetadata: map[string]any{"scope": "write"},
})
s.Require().NoError(err)
s.Require().Equal(first.Identity.ID, second.Identity.ID)
s.Require().Equal(first.Channel.ID, second.Channel.ID)
s.Require().Equal("second", second.Identity.Metadata["username"])
s.Require().Equal("write", second.Channel.Metadata["scope"])
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-2",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestCreateAuthIdentity_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-create")
_, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-create-mismatch",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "app-mismatch",
ChannelSubject: "openid-create-mismatch",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_RejectsChannelProviderMismatch() {
user := s.mustCreateUser("provider-mismatch-bind")
_, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-bind-mismatch",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-legacy",
Channel: "oa",
ChannelAppID: "wx-app-bind-mismatch",
ChannelSubject: "openid-bind-mismatch",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelProviderMismatch)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
user := s.mustCreateUser("tx-rollback")
expectedErr := errors.New("rollback")
err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
},
})
s.Require().NoError(err)
inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "oidc",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
return expectedErr
})
s.Require().ErrorIs(err, expectedErr)
_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
})
s.Require().True(dbent.IsNotFound(err))
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
user.ID,
"oidc",
string(ProviderGrantReasonFirstBind),
).Scan(&count))
s.Require().Zero(count)
}
func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
user := s.mustCreateUser("grant")
inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().False(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonSignup,
})
s.Require().NoError(err)
s.Require().True(inserted)
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2`,
user.ID,
"wechat",
).Scan(&count))
s.Require().Equal(2, count)
}
func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
user := s.mustCreateUser("adoption")
identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
s.Require().NoError(err)
session := s.mustCreatePendingAuthSession(identity.IdentityRef())
first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
s.Require().NoError(err)
s.Require().True(first.AdoptDisplayName)
s.Require().False(first.AdoptAvatar)
s.Require().Nil(first.IdentityID)
second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.Identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
s.Require().NoError(err)
s.Require().Equal(first.ID, second.ID)
s.Require().NotNil(second.IdentityID)
s.Require().Equal(identity.Identity.ID, *second.IdentityID)
s.Require().True(second.AdoptAvatar)
loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
s.Require().NoError(err)
s.Require().Equal(second.ID, loaded.ID)
s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
}
func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_ReassignsExistingIdentityReference() {
user := s.mustCreateUser("adoption-reassign")
identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption-reassign",
},
})
s.Require().NoError(err)
firstSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
firstDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: firstSession.ID,
IdentityID: &identity.Identity.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
s.Require().NoError(err)
s.Require().NotNil(firstDecision.IdentityID)
s.Require().Equal(identity.Identity.ID, *firstDecision.IdentityID)
secondSession := s.mustCreatePendingAuthSession(identity.IdentityRef())
secondDecision, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: secondSession.ID,
IdentityID: &identity.Identity.ID,
AdoptDisplayName: false,
AdoptAvatar: true,
})
s.Require().NoError(err)
s.Require().NotNil(secondDecision.IdentityID)
s.Require().Equal(identity.Identity.ID, *secondDecision.IdentityID)
reloadedFirst, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, firstSession.ID)
s.Require().NoError(err)
s.Require().Nil(reloadedFirst.IdentityID)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_AllowsAvatarOnlyProfileUpdate() {
user := s.mustCreateUser("avatar-only-update")
model, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(model)
err = s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
_, err := s.repo.UpsertUserAvatar(txCtx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/avatar.png",
})
if err != nil {
return err
}
return s.repo.Update(txCtx, model)
})
s.Require().NoError(err)
avatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(avatar)
s.Require().Equal("https://cdn.example.com/avatar.png", avatar.URL)
}
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
user := s.mustCreateUser("avatar")
inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "inline",
URL: "data:image/png;base64,QUJD",
ContentType: "image/png",
ByteSize: 3,
SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
})
s.Require().NoError(err)
s.Require().Equal("inline", inlineAvatar.StorageProvider)
s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("image/png", loadedAvatar.ContentType)
s.Require().Equal(3, loadedAvatar.ByteSize)
_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/avatar.png",
})
s.Require().NoError(err)
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
s.Require().Zero(loadedAvatar.ByteSize)
s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Nil(loadedAvatar)
}
func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
user := s.mustCreateUser("activity")
loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
activeAt := loginAt.Add(5 * time.Minute)
s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
var storedLoginAt sqlNullTime
var storedActiveAt sqlNullTime
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT last_login_at, last_active_at
FROM users
WHERE id = $1`,
user.ID,
).Scan(&storedLoginAt, &storedActiveAt))
s.Require().True(storedLoginAt.Valid)
s.Require().True(storedActiveAt.Valid)
s.Require().True(storedLoginAt.Time.Equal(loginAt))
s.Require().True(storedActiveAt.Time.Equal(activeAt))
}
type sqlNullTime struct {
Time time.Time
Valid bool
}
func (t *sqlNullTime) Scan(value any) error {
switch v := value.(type) {
case time.Time:
t.Time = v
t.Valid = true
return nil
case nil:
t.Time = time.Time{}
t.Valid = false
return nil
default:
return fmt.Errorf("unsupported scan type %T", value)
}
}
func stringPtr(v string) *string {
return &v
}
......@@ -11,12 +11,17 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
entsql "entgo.io/ent/dialect/sql"
)
......@@ -51,8 +56,12 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
created, err := txClient.User.Create().
......@@ -64,6 +73,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
......@@ -72,6 +84,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := ensureEmailAuthIdentityWithClient(ctx, txClient, created.ID, created.Email, "user_repo_create"); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
......@@ -101,10 +116,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User,
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
matches, err := r.client.User.Query().
Where(userEmailLookupPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
return nil, err
}
if len(matches) == 0 {
return nil, service.ErrUserNotFound
}
if len(matches) > 1 {
return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email))
}
m := matches[0]
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
......@@ -133,9 +158,18 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中(ErrTxStarted),复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
// 已处于外部事务中(ErrTxStarted),复用当前事务 client 并由调用方负责提交/回滚。
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
existing, err := clientFromContext(ctx, txClient).User.Get(ctx, userIn.ID)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
oldEmail := existing.Email
updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
......@@ -151,6 +185,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}
if userIn.LastLoginAt != nil {
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
}
if userIn.LastActiveAt != nil {
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
}
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
......@@ -162,6 +205,9 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if err := replaceEmailAuthIdentityWithClient(ctx, txClient, updated.ID, oldEmail, updated.Email, "user_repo_update"); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
......@@ -173,14 +219,146 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
return nil
}
func ensureEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, email string, source string) error {
client = clientFromContext(ctx, client)
if client == nil || userID <= 0 {
return nil
}
subject := normalizeEmailAuthIdentitySubject(email)
if subject == "" {
return nil
}
if err := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(subject).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": source}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return err
}
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(subject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity.UserID != userID {
return ErrAuthIdentityOwnershipConflict
}
return nil
}
func replaceEmailAuthIdentityWithClient(ctx context.Context, client *dbent.Client, userID int64, oldEmail, newEmail string, source string) error {
newSubject := normalizeEmailAuthIdentitySubject(newEmail)
if err := ensureEmailAuthIdentityWithClient(ctx, client, userID, newEmail, source); err != nil {
return err
}
oldSubject := normalizeEmailAuthIdentitySubject(oldEmail)
if oldSubject == "" || oldSubject == newSubject {
return nil
}
_, err := clientFromContext(ctx, client).AuthIdentity.Delete().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(oldSubject),
).
Exec(ctx)
return err
}
func normalizeEmailAuthIdentitySubject(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" {
return ""
}
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
return ""
}
return normalized
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
if existingTx := dbent.TxFromContext(ctx); existingTx != nil {
txClient = existingTx.Client()
} else {
txClient = r.client
}
}
identityIDs, err := txClient.AuthIdentity.Query().
Where(authidentity.UserIDEQ(id)).
IDs(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if len(identityIDs) > 0 {
if _, err := txClient.IdentityAdoptionDecision.Update().
Where(identityadoptiondecision.IdentityIDIn(identityIDs...)).
ClearIdentityID().
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if _, err := txClient.AuthIdentityChannel.Delete().
Where(authidentitychannel.IdentityIDIn(identityIDs...)).
Exec(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if _, err := txClient.AuthIdentity.Delete().
Where(authidentity.UserIDEQ(id)).
Exec(ctx); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
}
affected, err := txClient.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if affected == 0 {
return service.ErrUserNotFound
}
if tx != nil {
if err := tx.Commit(); err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
}
return nil
}
......@@ -298,8 +476,13 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
if sortBy == "last_used_at" {
return userLastUsedAtOrder(sortOrder)
}
var field string
defaultField := true
nullsLastField := false
switch sortBy {
case "email":
field = dbuser.FieldEmail
......@@ -322,6 +505,10 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at":
field = dbuser.FieldCreatedAt
defaultField = false
case "last_active_at":
field = dbuser.FieldLastActiveAt
defaultField = false
nullsLastField = true
default:
field = dbuser.FieldID
}
......@@ -330,14 +517,92 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
}
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
dbent.Asc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
}
if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
}
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
dbent.Desc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
}
func (r *userRepository) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
result := make(map[int64]*time.Time, len(userIDs))
if len(userIDs) == 0 {
return result, nil
}
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
const query = `
SELECT user_id, MAX(created_at) AS last_used_at
FROM usage_logs
WHERE user_id = ANY($1)
GROUP BY user_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var (
userID int64
lastUsedAt time.Time
)
if scanErr := rows.Scan(&userID, &lastUsedAt); scanErr != nil {
return nil, scanErr
}
ts := lastUsedAt.UTC()
result[userID] = &ts
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
latestByUserID, err := r.GetLatestUsedAtByUserIDs(ctx, []int64{userID})
if err != nil {
return nil, err
}
return latestByUserID[userID], nil
}
func userLastUsedAtOrder(sortOrder string) []func(*entsql.Selector) {
orderExpr := func(direction, nulls string, tieOrder func(string) string) func(*entsql.Selector) {
return func(s *entsql.Selector) {
subquery := fmt.Sprintf("(SELECT MAX(created_at) FROM usage_logs WHERE user_id = %s)", s.C(dbuser.FieldID))
s.OrderExpr(entsql.Expr(subquery + " " + direction + " NULLS " + nulls))
s.OrderBy(tieOrder(s.C(dbuser.FieldID)))
}
}
if sortOrder == pagination.SortOrderAsc {
return []func(*entsql.Selector){
orderExpr("ASC", "FIRST", entsql.Asc),
}
}
return []func(*entsql.Selector){
orderExpr("DESC", "LAST", entsql.Desc),
}
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
......@@ -436,17 +701,36 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx)
}
func userEmailLookupPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" {
return dbuser.EmailEQ(email)
}
return predicate.User(func(s *entsql.Selector) {
s.Where(entsql.P(func(b *entsql.Builder) {
b.WriteString("LOWER(TRIM(").
Ident(s.C(dbuser.FieldEmail)).
WriteString(")) = ").
Arg(normalized)
}))
})
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
return client.UserAllowedGroup.Create().
err := client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
if isSQLNoRowsError(err) {
return nil
}
return err
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
......@@ -546,6 +830,9 @@ func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, cl
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
if isSQLNoRowsError(err) {
return nil
}
return err
}
}
......@@ -558,10 +845,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return
}
dst.ID = src.ID
dst.SignupSource = src.SignupSource
dst.LastLoginAt = src.LastLoginAt
dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
......
//go:build integration
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) TestCreate_CreatesEmailAuthIdentityForNormalEmail() {
user := &service.User{
Email: "repo-create@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 2,
}
s.Require().NoError(s.repo.Create(s.ctx, user))
identity, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("repo-create@example.com"),
).
Only(s.ctx)
s.Require().NoError(err)
s.Require().Equal(user.ID, identity.UserID)
}
func (s *UserRepoSuite) TestCreate_SkipsEmailAuthIdentityForSyntheticLinuxDoEmail() {
user := &service.User{
Email: "linuxdo-legacy-user@linuxdo-connect.invalid",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 2,
}
s.Require().NoError(s.repo.Create(s.ctx, user))
count, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(count)
}
func (s *UserRepoSuite) TestUpdate_ReplacesEmailAuthIdentityWhenEmailChanges() {
user := s.mustCreateUser(&service.User{
Email: "before-update@example.com",
})
user.Email = "after-update@example.com"
s.Require().NoError(s.repo.Update(s.ctx, user))
newIdentity, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("after-update@example.com"),
).
Only(s.ctx)
s.Require().NoError(err)
s.Require().Equal(user.ID, newIdentity.UserID)
oldCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("before-update@example.com"),
).
Count(context.Background())
s.Require().NoError(err)
s.Require().Zero(oldCount)
}
package repository
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return newUserRepositoryWithSQL(client, db), client
}
func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
err := repo.Create(ctx, &service.User{
Email: " Legacy@Example.com ",
Username: "legacy-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.NoError(t, err)
got, err := repo.GetByEmail(ctx, "legacy@example.com")
require.NoError(t, err)
require.Equal(t, " Legacy@Example.com ", got.Email)
}
func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) {
repo, _ := newUserEntRepo(t)
ctx := context.Background()
err := repo.Create(ctx, &service.User{
Email: " Legacy@Example.com ",
Username: "legacy-user",
PasswordHash: "hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
require.NoError(t, err)
exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ")
require.NoError(t, err)
require.True(t, exists)
}
......@@ -8,6 +8,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
......@@ -26,6 +28,8 @@ func (s *UserRepoSuite) SetupTest() {
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identity_channels")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM auth_identities")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
......@@ -122,11 +126,27 @@ func (s *UserRepoSuite) TestGetByEmail() {
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NormalizesSpacingAndCaseOnPostgres() {
user := s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
got, err := s.repo.GetByEmail(s.ctx, " legacy@example.com ")
s.Require().NoError(err, "GetByEmail normalized lookup")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestExistsByEmail_NormalizesSpacingAndCaseOnPostgres() {
s.mustCreateUser(&service.User{Email: " Legacy@Example.com "})
exists, err := s.repo.ExistsByEmail(s.ctx, " LEGACY@example.com ")
s.Require().NoError(err, "ExistsByEmail normalized lookup")
s.Require().True(exists)
}
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
......@@ -140,6 +160,30 @@ func (s *UserRepoSuite) TestUpdate() {
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestUpdateIgnoresNoRowsFromConflictingEmailIdentityUpsert() {
user := s.mustCreateUser(&service.User{Email: "update-existing-identity@test.com", Username: "original"})
identityCount, err := s.client.AuthIdentity.Query().
Where(
authidentity.UserIDEQ(user.ID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("update-existing-identity@test.com"),
).
Count(s.ctx)
s.Require().NoError(err)
s.Require().Equal(1, identityCount)
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
got.Username = "updated"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update should tolerate ON CONFLICT DO NOTHING returning no rows")
updated, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
......@@ -150,6 +194,39 @@ func (s *UserRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
func (s *UserRepoSuite) TestDeleteRemovesAuthIdentitiesAndChannels() {
user := s.mustCreateUser(&service.User{Email: "delete-oauth@test.com"})
identity, err := s.client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("delete-oauth-subject").
Save(s.ctx)
s.Require().NoError(err)
_, err = s.client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType("wechat").
SetProviderKey("wechat").
SetChannel("open").
SetChannelAppID("app-id").
SetChannelSubject("openid-123").
Save(s.ctx)
s.Require().NoError(err)
err = s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err)
identityCount, err := s.client.AuthIdentity.Query().Where(authidentity.UserIDEQ(user.ID)).Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(identityCount)
channelCount, err := s.client.AuthIdentityChannel.Query().Where(authidentitychannel.IdentityIDEQ(identity.ID)).Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(channelCount)
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
......
......@@ -4,11 +4,30 @@ package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (s *UserRepoSuite) mustInsertUsageLog(userID int64, createdAt time.Time) {
s.T().Helper()
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-log-account"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: userID})
_, err := integrationDB.ExecContext(
s.ctx,
`INSERT INTO usage_logs (user_id, api_key_id, account_id, model, input_tokens, output_tokens, total_cost, actual_cost, created_at)
VALUES ($1, $2, $3, 'gpt-test', 1, 1, 0.01, 0.01, $4)`,
userID,
apiKey.ID,
account.ID,
createdAt.UTC(),
)
s.Require().NoError(err)
}
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
......@@ -36,4 +55,110 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID)
}
func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
created := s.mustCreateUser(&service.User{
Email: "identity-meta@example.com",
SignupSource: "linuxdo",
LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
})
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("linuxdo", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
created.SignupSource = "oidc"
created.LastLoginAt = &lastLoginAt
created.LastActiveAt = &lastActiveAt
s.Require().NoError(s.repo.Update(s.ctx, created))
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("oidc", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_active_at",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal("earlier-active@example.com", users[0].Email)
s.Require().Equal("later-active@example.com", users[1].Email)
s.Require().Equal("nil-active@example.com", users[2].Email)
}
func (s *UserRepoSuite) TestGetLatestUsedAtByUserIDs_UsesUsageLogs() {
older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Second)
newer := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Second)
userWithUsage := s.mustCreateUser(&service.User{Email: "usage-source@example.com"})
userWithoutUsage := s.mustCreateUser(&service.User{Email: "usage-missing@example.com"})
s.mustInsertUsageLog(userWithUsage.ID, older)
s.mustInsertUsageLog(userWithUsage.ID, newer)
got, err := s.repo.GetLatestUsedAtByUserIDs(s.ctx, []int64{userWithUsage.ID, userWithoutUsage.ID})
s.Require().NoError(err)
s.Require().Contains(got, userWithUsage.ID)
s.Require().NotContains(got, userWithoutUsage.ID)
s.Require().NotNil(got[userWithUsage.ID])
s.Require().True(got[userWithUsage.ID].Equal(newer))
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastUsedAtDesc_UsesUsageLogsNotLastActiveAt() {
lastUsedOlder := time.Now().Add(-6 * time.Hour).UTC().Truncate(time.Second)
lastUsedNewer := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Second)
lastActiveVeryRecent := time.Now().Add(-10 * time.Minute).UTC().Truncate(time.Second)
nilUsage := s.mustCreateUser(&service.User{Email: "nil-last-used@example.com"})
wrongSource := s.mustCreateUser(&service.User{
Email: "active-not-usage@example.com",
LastActiveAt: &lastActiveVeryRecent,
})
rightSource := s.mustCreateUser(&service.User{Email: "usage-wins@example.com"})
s.mustInsertUsageLog(wrongSource.ID, lastUsedOlder)
s.mustInsertUsageLog(rightSource.ID, lastUsedNewer)
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_used_at",
SortOrder: "desc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal(rightSource.ID, users[0].ID)
s.Require().Equal(wrongSource.ID, users[1].ID)
s.Require().Equal(nilUsage.ID, users[2].ID)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
......@@ -50,6 +50,7 @@ func TestAPIContracts(t *testing.T) {
"data": {
"id": 1,
"email": "alice@example.com",
"email_bound": true,
"username": "alice",
"role": "user",
"balance": 12.5,
......@@ -63,6 +64,120 @@ func TestAPIContracts(t *testing.T) {
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
"linuxdo_bound": false,
"oidc_bound": false,
"wechat_bound": false,
"identities": {
"email": {
"provider": "email",
"provider_key": "email",
"bound": true,
"bound_count": 1,
"can_bind": false,
"can_unbind": false,
"display_name": "alice@example.com",
"subject_hint": "a***e@example.com",
"note": "Primary account email is managed from the profile form."
},
"linuxdo": {
"provider": "linuxdo",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
"email": {
"provider": "email",
"provider_key": "email",
"bound": true,
"bound_count": 1,
"can_bind": false,
"can_unbind": false,
"display_name": "alice@example.com",
"subject_hint": "a***e@example.com",
"note": "Primary account email is managed from the profile form."
},
"linuxdo": {
"provider": "linuxdo",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
"email": {
"provider": "email",
"provider_key": "email",
"bound": true,
"bound_count": 1,
"can_bind": false,
"can_unbind": false,
"display_name": "alice@example.com",
"subject_hint": "a***e@example.com",
"note": "Primary account email is managed from the profile form."
},
"linuxdo": {
"provider": "linuxdo",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"oidc": {
"provider": "oidc",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/oidc/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
},
"wechat": {
"provider": "wechat",
"bound": false,
"bound_count": 0,
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
}
}`,
......@@ -479,7 +594,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false",
service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
......@@ -500,10 +615,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyTableDefaultPageSize: "20",
service.SettingKeyTablePageSizeOptions: "[10,20,50,100]",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay,
service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat,
service.SettingPaymentVisibleMethodAlipayEnabled: "true",
service.SettingPaymentVisibleMethodWxpayEnabled: "false",
"openai_advanced_scheduler_enabled": "true",
})
},
method: http.MethodGet,
......@@ -549,7 +669,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false,
"oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120,
......@@ -567,6 +687,27 @@ func TestAPIContracts(t *testing.T) {
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
"auth_source_default_email_balance": 0,
"auth_source_default_email_concurrency": 5,
"auth_source_default_email_subscriptions": [],
"auth_source_default_email_grant_on_signup": false,
"auth_source_default_email_grant_on_first_bind": false,
"auth_source_default_linuxdo_balance": 0,
"auth_source_default_linuxdo_concurrency": 5,
"auth_source_default_linuxdo_subscriptions": [],
"auth_source_default_linuxdo_grant_on_signup": false,
"auth_source_default_linuxdo_grant_on_first_bind": false,
"auth_source_default_oidc_balance": 0,
"auth_source_default_oidc_concurrency": 5,
"auth_source_default_oidc_subscriptions": [],
"auth_source_default_oidc_grant_on_signup": false,
"auth_source_default_oidc_grant_on_first_bind": false,
"auth_source_default_wechat_balance": 0,
"auth_source_default_wechat_concurrency": 5,
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
"default_subscriptions": [],
......@@ -592,6 +733,11 @@ func TestAPIContracts(t *testing.T) {
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
"payment_visible_method_alipay_source": "easypay_alipay",
"payment_visible_method_wxpay_source": "official_wxpay",
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false,
......@@ -618,7 +764,23 @@ func TestAPIContracts(t *testing.T) {
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": []
"account_quota_notify_emails": [],
"wechat_connect_enabled": false,
"wechat_connect_app_id": "",
"wechat_connect_app_secret_configured": false,
"wechat_connect_mode": "open",
"wechat_connect_open_enabled": false,
"wechat_connect_open_app_id": "",
"wechat_connect_open_app_secret_configured": false,
"wechat_connect_mp_enabled": false,
"wechat_connect_mp_app_id": "",
"wechat_connect_mp_app_secret_configured": false,
"wechat_connect_mobile_enabled": false,
"wechat_connect_mobile_app_id": "",
"wechat_connect_mobile_app_secret_configured": false,
"wechat_connect_redirect_url": "",
"wechat_connect_frontend_redirect_url": "/auth/wechat/callback",
"wechat_connect_scopes": "snsapi_login"
}
}`,
},
......@@ -858,6 +1020,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
return nil, nil
}
func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
......@@ -894,6 +1068,26 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
return errors.New("not implemented")
}
func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
return nil, nil
}
func (r *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
return nil, nil
}
func (r *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
return nil
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
......
......@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
......@@ -153,6 +154,18 @@ func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
return nil, nil
}
func (s *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
......@@ -161,6 +174,18 @@ func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserIDs call")
}
func (s *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserID call")
}
func (s *stubUserRepo) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
panic("unexpected UpdateUserLastActiveAt call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
......@@ -189,6 +214,14 @@ func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (s *stubUserRepo) UnbindUserAuthProvider(context.Context, int64, string) error {
panic("unexpected UnbindUserAuthProvider call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
......
package middleware
import (
"context"
"errors"
"strings"
......@@ -11,11 +12,19 @@ import (
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
return JWTAuthMiddleware(jwtAuth(authService, userService))
return JWTAuthMiddleware(jwtAuth(authService, userService, userService))
}
type jwtUserReader interface {
GetByID(ctx context.Context, id int64) (*service.User, error)
}
type userActivityToucher interface {
TouchLastActiveForUser(ctx context.Context, user *service.User)
}
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
func jwtAuth(authService *service.AuthService, userService jwtUserReader, activityToucher userActivityToucher) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
......@@ -73,6 +82,9 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
if activityToucher != nil {
activityToucher.TouchLastActiveForUser(c.Request.Context(), user)
}
c.Next()
}
......
......@@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -30,6 +31,25 @@ func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, e
return u, nil
}
func (r *stubJWTUserRepo) GetUserAvatar(_ context.Context, _ int64) (*service.UserAvatar, error) {
return nil, nil
}
func (r *stubJWTUserRepo) UpdateUserLastActiveAt(_ context.Context, _ int64, _ time.Time) error {
return nil
}
type recordingActivityToucher struct {
userIDs []int64
}
func (r *recordingActivityToucher) TouchLastActiveForUser(_ context.Context, user *service.User) {
if user == nil {
return
}
r.userIDs = append(r.userIDs, user.ID)
}
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
......@@ -106,6 +126,45 @@ func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
func TestJWTAuth_ValidToken_TouchesLastActive(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
Concurrency: 5,
TokenVersion: 1,
}
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: map[int64]*service.User{1: user}}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil, nil)
toucher := &recordingActivityToucher{}
r := gin.New()
r.Use(jwtAuth(authSvc, userSvc, toucher))
r.GET("/protected", func(c *gin.Context) {
c.Status(http.StatusOK)
})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, []int64{1}, toucher.userIDs)
}
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
......
......@@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment