Commit e9de839d authored by IanShaw027's avatar IanShaw027
Browse files

feat: rebuild auth identity foundation flow

parent fbd0a2e3
...@@ -189,6 +189,7 @@ type PublicSettings struct { ...@@ -189,6 +189,7 @@ type PublicSettings struct {
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"`
SoraClientEnabled bool `json:"sora_client_enabled"` SoraClientEnabled bool `json:"sora_client_enabled"`
......
...@@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) ...@@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string)
// This allows looking up the correct provider instance before verification. // This allows looking up the correct provider instance before verification.
func extractOutTradeNo(rawBody, providerKey string) string { func extractOutTradeNo(rawBody, providerKey string) string {
switch providerKey { switch providerKey {
case payment.TypeEasyPay: case payment.TypeEasyPay, payment.TypeAlipay:
values, err := url.ParseQuery(rawBody) values, err := url.ParseQuery(rawBody)
if err == nil { if err == nil {
return values.Get("out_trade_no") return values.Get("out_trade_no")
......
...@@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) { ...@@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) {
assert.Equal(t, 200, webhookLogTruncateLen) assert.Equal(t, 200, webhookLogTruncateLen)
}) })
} }
func TestExtractOutTradeNo(t *testing.T) {
tests := []struct {
name string
providerKey string
rawBody string
want string
}{
{
name: "easypay query payload",
providerKey: "easypay",
rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS",
want: "sub2_123",
},
{
name: "alipay query payload",
providerKey: "alipay",
rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456",
want: "sub2_456",
},
{
name: "unknown provider",
providerKey: "wxpay",
rawBody: "{}",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey))
})
}
}
...@@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ...@@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
OIDCOAuthProviderName: settings.OIDCOAuthProviderName, OIDCOAuthProviderName: settings.OIDCOAuthProviderName,
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
......
...@@ -34,10 +34,16 @@ type ChangePasswordRequest struct { ...@@ -34,10 +34,16 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload // UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Username *string `json:"username"` Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
type userProfileResponse struct {
dto.User
AvatarURL string `json:"avatar_url,omitempty"`
}
// GetProfile handles getting user profile // GetProfile handles getting user profile
// GET /api/v1/users/me // GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) { func (h *UserHandler) GetProfile(c *gin.Context) {
...@@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) { ...@@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
return return
} }
userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, dto.UserFromService(userData)) response.Success(c, userProfileResponseFromService(userData))
} }
// ChangePassword handles changing user password // ChangePassword handles changing user password
...@@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
svcReq := service.UpdateProfileRequest{ svcReq := service.UpdateProfileRequest{
Username: req.Username, Username: req.Username,
AvatarURL: req.AvatarURL,
BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold, BalanceNotifyThreshold: req.BalanceNotifyThreshold,
} }
...@@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { ...@@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
return return
} }
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, userProfileResponseFromService(updatedUser))
} }
// SendNotifyEmailCodeRequest represents the request to send notify email verification code // SendNotifyEmailCodeRequest represents the request to send notify email verification code
...@@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { ...@@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
return return
} }
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, userProfileResponseFromService(updatedUser))
} }
// RemoveNotifyEmailRequest represents the request to remove a notify email // RemoveNotifyEmailRequest represents the request to remove a notify email
...@@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { ...@@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
return return
} }
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, userProfileResponseFromService(updatedUser))
} }
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
...@@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { ...@@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
return return
} }
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, userProfileResponseFromService(updatedUser))
}
func userProfileResponseFromService(user *service.User) userProfileResponse {
base := dto.UserFromService(user)
if base == nil {
return userProfileResponse{}
}
return userProfileResponse{
User: *base,
AvatarURL: user.AvatarURL,
}
} }
//go:build unit
package handler
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type userHandlerRepoStub struct {
user *service.User
}
func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil }
func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) {
cloned := *s.user
return &cloned, nil
}
func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error {
cloned := *user
s.user = &cloned
return nil
}
func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) {
if s.user == nil || s.user.AvatarURL == "" {
return nil, nil
}
return &service.UserAvatar{
StorageProvider: s.user.AvatarSource,
URL: s.user.AvatarURL,
ContentType: s.user.AvatarMIME,
ByteSize: s.user.AvatarByteSize,
SHA256: s.user.AvatarSHA256,
}, nil
}
func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
s.user.AvatarURL = input.URL
s.user.AvatarSource = input.StorageProvider
s.user.AvatarMIME = input.ContentType
s.user.AvatarByteSize = input.ByteSize
s.user.AvatarSHA256 = input.SHA256
return &service.UserAvatar{
StorageProvider: input.StorageProvider,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error {
s.user.AvatarURL = ""
s.user.AvatarSource = ""
s.user.AvatarMIME = ""
s.user.AvatarByteSize = 0
s.user.AvatarSHA256 = ""
return nil
}
func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil }
func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 11,
Email: "handler-avatar@example.com",
Username: "handler-avatar",
Role: service.RoleUser,
Status: service.StatusActive,
},
}
handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil)
body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11})
handler.UpdateProfile(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
AvatarURL string `json:"avatar_url"`
Username string `json:"username"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL)
require.Equal(t, "handler-avatar", resp.Data.Username)
}
...@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se ...@@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails, user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged, user.FieldTotalRecharged,
user.FieldSignupSource,
user.FieldLastLoginAt,
user.FieldLastActiveAt,
) )
}). }).
WithGroup(func(q *dbent.GroupQuery) { WithGroup(func(q *dbent.GroupQuery) {
...@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User { ...@@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User {
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
SignupSource: u.SignupSource,
LastLoginAt: u.LastLoginAt,
LastActiveAt: u.LastActiveAt,
TotpSecretEncrypted: u.TotpSecretEncrypted, TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled, TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt, TotpEnabledAt: u.TotpEnabledAt,
......
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
)
type AuthIdentityMigrationReport struct {
ID int64
ReportType string
ReportKey string
Details map[string]any
CreatedAt time.Time
}
type AuthIdentityMigrationReportQuery struct {
ReportType string
Limit int
Offset int
}
type AuthIdentityMigrationReportSummary struct {
Total int64
ByType map[string]int64
}
func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
limit := query.Limit
if limit <= 0 {
limit = 100
}
rows, err := exec.QueryContext(ctx, `
SELECT id, report_type, report_key, details, created_at
FROM auth_identity_migration_reports
WHERE ($1 = '' OR report_type = $1)
ORDER BY created_at DESC, id DESC
LIMIT $2 OFFSET $3`,
strings.TrimSpace(query.ReportType),
limit,
query.Offset,
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
reports := make([]AuthIdentityMigrationReport, 0)
for rows.Next() {
report, scanErr := scanAuthIdentityMigrationReport(rows)
if scanErr != nil {
return nil, scanErr
}
reports = append(reports, report)
}
if err := rows.Err(); err != nil {
return nil, err
}
return reports, nil
}
func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
rows, err := exec.QueryContext(ctx, `
SELECT id, report_type, report_key, details, created_at
FROM auth_identity_migration_reports
WHERE report_type = $1 AND report_key = $2
LIMIT 1`,
strings.TrimSpace(reportType),
strings.TrimSpace(reportKey),
)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, sql.ErrNoRows
}
report, err := scanAuthIdentityMigrationReport(rows)
if err != nil {
return nil, err
}
return &report, rows.Err()
}
func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
rows, err := exec.QueryContext(ctx, `
SELECT report_type, COUNT(*)
FROM auth_identity_migration_reports
GROUP BY report_type
ORDER BY report_type ASC`)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
summary := &AuthIdentityMigrationReportSummary{
ByType: make(map[string]int64),
}
for rows.Next() {
var reportType string
var count int64
if err := rows.Scan(&reportType, &count); err != nil {
return nil, err
}
summary.ByType[reportType] = count
summary.Total += count
}
if err := rows.Err(); err != nil {
return nil, err
}
return summary, nil
}
func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) {
var (
report AuthIdentityMigrationReport
details []byte
)
if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil {
return AuthIdentityMigrationReport{}, err
}
report.Details = map[string]any{}
if len(details) > 0 {
if err := json.Unmarshal(details, &report.Details); err != nil {
return AuthIdentityMigrationReport{}, err
}
}
return report, nil
}
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"unsafe"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
)
var (
ErrAuthIdentityOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_OWNERSHIP_CONFLICT",
"auth identity already belongs to another user",
)
ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict(
"AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT",
"auth identity channel already belongs to another user",
)
)
type ProviderGrantReason string
const (
ProviderGrantReasonSignup ProviderGrantReason = "signup"
ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind"
)
type AuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type AuthIdentityChannelKey struct {
ProviderType string
ProviderKey string
Channel string
ChannelAppID string
ChannelSubject string
}
type CreateAuthIdentityInput struct {
UserID int64
Canonical AuthIdentityKey
Channel *AuthIdentityChannelKey
Issuer *string
VerifiedAt *time.Time
Metadata map[string]any
ChannelMetadata map[string]any
}
type BindAuthIdentityInput = CreateAuthIdentityInput
type CreateAuthIdentityResult struct {
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey {
if r == nil || r.Identity == nil {
return AuthIdentityKey{}
}
return AuthIdentityKey{
ProviderType: r.Identity.ProviderType,
ProviderKey: r.Identity.ProviderKey,
ProviderSubject: r.Identity.ProviderSubject,
}
}
func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey {
if r == nil || r.Channel == nil {
return nil
}
return &AuthIdentityChannelKey{
ProviderType: r.Channel.ProviderType,
ProviderKey: r.Channel.ProviderKey,
Channel: r.Channel.Channel,
ChannelAppID: r.Channel.ChannelAppID,
ChannelSubject: r.Channel.ChannelSubject,
}
}
type UserAuthIdentityLookup struct {
User *dbent.User
Identity *dbent.AuthIdentity
Channel *dbent.AuthIdentityChannel
}
type ProviderGrantRecordInput struct {
UserID int64
ProviderType string
GrantReason ProviderGrantReason
}
type IdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type sqlQueryExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error {
if dbent.TxFromContext(ctx) != nil {
return fn(ctx)
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := fn(txCtx); err != nil {
return err
}
return tx.Commit()
}
func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) {
client := clientFromContext(ctx, r.client)
create := client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt)
identity, err := create.Save(ctx)
if err != nil {
return nil, err
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(ctx)
if err != nil {
return nil, err
}
}
return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil
}
func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) {
identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)),
).
WithUser().
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: identity.Edges.User,
Identity: identity,
}, nil
}
func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) {
channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)),
).
WithIdentity(func(q *dbent.AuthIdentityQuery) {
q.WithUser()
}).
Only(ctx)
if err != nil {
return nil, err
}
return &UserAuthIdentityLookup{
User: channel.Edges.Identity.Edges.User,
Identity: channel.Edges.Identity,
Channel: channel,
}, nil
}
func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) {
var result *CreateAuthIdentityResult
err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error {
client := clientFromContext(txCtx, r.client)
canonical := input.Canonical
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)),
).
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if identity != nil && identity.UserID != input.UserID {
return ErrAuthIdentityOwnershipConflict
}
if identity == nil {
identity, err = client.AuthIdentity.Create().
SetUserID(input.UserID).
SetProviderType(strings.TrimSpace(canonical.ProviderType)).
SetProviderKey(strings.TrimSpace(canonical.ProviderKey)).
SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)).
SetMetadata(copyMetadata(input.Metadata)).
SetNillableIssuer(input.Issuer).
SetNillableVerifiedAt(input.VerifiedAt).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentity.UpdateOneID(identity.ID)
if input.Metadata != nil {
update = update.SetMetadata(copyMetadata(input.Metadata))
}
if input.Issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*input.Issuer))
}
if input.VerifiedAt != nil {
update = update.SetVerifiedAt(*input.VerifiedAt)
}
identity, err = update.Save(txCtx)
if err != nil {
return err
}
}
var channel *dbent.AuthIdentityChannel
if input.Channel != nil {
channel, err = client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)),
authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)),
authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)),
authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)),
authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)),
).
WithIdentity().
Only(txCtx)
if err != nil && !dbent.IsNotFound(err) {
return err
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID {
return ErrAuthIdentityChannelOwnershipConflict
}
if channel == nil {
channel, err = client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(strings.TrimSpace(input.Channel.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)).
SetChannel(strings.TrimSpace(input.Channel.Channel)).
SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)).
SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)).
SetMetadata(copyMetadata(input.ChannelMetadata)).
Save(txCtx)
if err != nil {
return err
}
} else {
update := client.AuthIdentityChannel.UpdateOneID(channel.ID).
SetIdentityID(identity.ID)
if input.ChannelMetadata != nil {
update = update.SetMetadata(copyMetadata(input.ChannelMetadata))
}
channel, err = update.Save(txCtx)
if err != nil {
return err
}
}
}
result = &CreateAuthIdentityResult{Identity: identity, Channel: channel}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return false, fmt.Errorf("sql executor is not configured")
}
result, err := exec.ExecContext(ctx, `
INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
input.UserID,
strings.TrimSpace(input.ProviderType),
string(input.GrantReason),
)
if err != nil {
return false, err
}
affected, err := result.RowsAffected()
if err != nil {
return false, err
}
return affected > 0, nil
}
func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
client := clientFromContext(ctx, r.client)
current, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
now := time.Now().UTC()
if current == nil {
create := client.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(now)
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := client.IdentityAdoptionDecision.UpdateOneID(current.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) {
return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)).
Only(ctx)
}
func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastLoginAt(loginAt).
Save(ctx)
return err
}
func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
_, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID).
SetLastActiveAt(activeAt).
Save(ctx)
return err
}
func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
rows, err := exec.QueryContext(ctx, `
SELECT storage_provider, storage_key, url, content_type, byte_size, sha256
FROM user_avatars
WHERE user_id = $1`, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return nil, rows.Err()
}
var avatar service.UserAvatar
if err := rows.Scan(
&avatar.StorageProvider,
&avatar.StorageKey,
&avatar.URL,
&avatar.ContentType,
&avatar.ByteSize,
&avatar.SHA256,
); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return &avatar, nil
}
func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return nil, err
}
_, err = exec.ExecContext(ctx, `
INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (user_id) DO UPDATE SET
storage_provider = EXCLUDED.storage_provider,
storage_key = EXCLUDED.storage_key,
url = EXCLUDED.url,
content_type = EXCLUDED.content_type,
byte_size = EXCLUDED.byte_size,
sha256 = EXCLUDED.sha256,
updated_at = NOW()`,
userID,
strings.TrimSpace(input.StorageProvider),
strings.TrimSpace(input.StorageKey),
strings.TrimSpace(input.URL),
strings.TrimSpace(input.ContentType),
input.ByteSize,
strings.TrimSpace(input.SHA256),
)
if err != nil {
return nil, err
}
return &service.UserAvatar{
StorageProvider: strings.TrimSpace(input.StorageProvider),
StorageKey: strings.TrimSpace(input.StorageKey),
URL: strings.TrimSpace(input.URL),
ContentType: strings.TrimSpace(input.ContentType),
ByteSize: input.ByteSize,
SHA256: strings.TrimSpace(input.SHA256),
}, nil
}
func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error {
exec, err := r.userProfileIdentitySQL(ctx)
if err != nil {
return err
}
_, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID)
return err
}
func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error {
if user == nil {
return nil
}
avatar, err := r.GetUserAvatar(ctx, user.ID)
if err != nil {
return err
}
if avatar == nil {
return nil
}
user.AvatarURL = avatar.URL
user.AvatarSource = avatar.StorageProvider
user.AvatarMIME = avatar.ContentType
user.AvatarByteSize = avatar.ByteSize
user.AvatarSHA256 = avatar.SHA256
return nil
}
func copyMetadata(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor {
if tx := dbent.TxFromContext(ctx); tx != nil {
if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil {
return exec
}
}
if fallback != nil {
return fallback
}
return sqlExecutorFromEntClient(client)
}
func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) {
exec := txAwareSQLExecutor(ctx, r.sql, r.client)
if exec == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
return exec, nil
}
func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor {
if client == nil {
return nil
}
clientValue := reflect.ValueOf(client).Elem()
configValue := clientValue.FieldByName("config")
driverValue := configValue.FieldByName("driver")
if !driverValue.IsValid() {
return nil
}
driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface()
exec, ok := driver.(sqlQueryExecutor)
if !ok {
return nil
}
return exec
}
//go:build integration
package repository
import (
"context"
"errors"
"fmt"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type UserProfileIdentityRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userRepository
}
func TestUserProfileIdentityRepoSuite(t *testing.T) {
suite.Run(t, new(UserProfileIdentityRepoSuite))
}
func (s *UserProfileIdentityRepoSuite) SetupTest() {
s.ctx = context.Background()
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
_, err := integrationDB.ExecContext(s.ctx, `
TRUNCATE TABLE
identity_adoption_decisions,
auth_identity_channels,
auth_identities,
pending_auth_sessions,
auth_identity_migration_reports,
user_provider_default_grants,
user_avatars
RESTART IDENTITY`)
s.Require().NoError(err)
}
func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User {
s.T().Helper()
user, err := s.client.User.Create().
SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())).
SetPasswordHash("test-password-hash").
SetRole("user").
SetStatus("active").
Save(s.ctx)
s.Require().NoError(err)
return user
}
func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession {
s.T().Helper()
session, err := s.client.PendingAuthSession.Create().
SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())).
SetIntent("bind_current_user").
SetProviderType(key.ProviderType).
SetProviderKey(key.ProviderKey).
SetProviderSubject(key.ProviderSubject).
SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)).
SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}).
SetLocalFlowState(map[string]any{"step": "pending"}).
Save(s.ctx)
s.Require().NoError(err)
return session
}
func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() {
user := s.mustCreateUser("canonical-channel")
verifiedAt := time.Now().UTC().Truncate(time.Second)
created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
Channel: "mp",
ChannelAppID: "wx-app",
ChannelSubject: "openid-123",
},
Issuer: stringPtr("https://issuer.example"),
VerifiedAt: &verifiedAt,
Metadata: map[string]any{"unionid": "union-123"},
ChannelMetadata: map[string]any{"openid": "openid-123"},
})
s.Require().NoError(err)
s.Require().NotNil(created.Identity)
s.Require().NotNil(created.Channel)
canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, canonical.User.ID)
s.Require().Equal(created.Identity.ID, canonical.Identity.ID)
s.Require().Equal("union-123", canonical.Identity.ProviderSubject)
channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef())
s.Require().NoError(err)
s.Require().Equal(user.ID, channel.User.ID)
s.Require().Equal(created.Identity.ID, channel.Identity.ID)
s.Require().Equal(created.Channel.ID, channel.Channel.ID)
}
func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() {
owner := s.mustCreateUser("owner")
other := s.mustCreateUser("other")
first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "first"},
ChannelMetadata: map[string]any{"scope": "read"},
})
s.Require().NoError(err)
second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: owner.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
Metadata: map[string]any{"username": "second"},
ChannelMetadata: map[string]any{"scope": "write"},
})
s.Require().NoError(err)
s.Require().Equal(first.Identity.ID, second.Identity.ID)
s.Require().Equal(first.Channel.ID, second.Channel.ID)
s.Require().Equal("second", second.Identity.Metadata["username"])
s.Require().Equal("write", second.Channel.Metadata["scope"])
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict)
_, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{
UserID: other.ID,
Canonical: AuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-2",
},
Channel: &AuthIdentityChannelKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
Channel: "oauth",
ChannelAppID: "linuxdo-web",
ChannelSubject: "subject-1",
},
})
s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict)
}
func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() {
user := s.mustCreateUser("tx-rollback")
expectedErr := errors.New("rollback")
err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error {
_, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
},
})
s.Require().NoError(err)
inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "oidc",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
return expectedErr
})
s.Require().ErrorIs(err, expectedErr)
_, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-rollback",
})
s.Require().True(dbent.IsNotFound(err))
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`,
user.ID,
"oidc",
string(ProviderGrantReasonFirstBind),
).Scan(&count))
s.Require().Zero(count)
}
func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() {
user := s.mustCreateUser("grant")
inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().True(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonFirstBind,
})
s.Require().NoError(err)
s.Require().False(inserted)
inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{
UserID: user.ID,
ProviderType: "wechat",
GrantReason: ProviderGrantReasonSignup,
})
s.Require().NoError(err)
s.Require().True(inserted)
var count int
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT COUNT(*)
FROM user_provider_default_grants
WHERE user_id = $1 AND provider_type = $2`,
user.ID,
"wechat",
).Scan(&count))
s.Require().Equal(2, count)
}
func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() {
user := s.mustCreateUser("adoption")
identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{
UserID: user.ID,
Canonical: AuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
s.Require().NoError(err)
session := s.mustCreatePendingAuthSession(identity.IdentityRef())
first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
s.Require().NoError(err)
s.Require().True(first.AdoptDisplayName)
s.Require().False(first.AdoptAvatar)
s.Require().Nil(first.IdentityID)
second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.Identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
s.Require().NoError(err)
s.Require().Equal(first.ID, second.ID)
s.Require().NotNil(second.IdentityID)
s.Require().Equal(identity.Identity.ID, *second.IdentityID)
s.Require().True(second.AdoptAvatar)
loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID)
s.Require().NoError(err)
s.Require().Equal(second.ID, loaded.ID)
s.Require().Equal(identity.Identity.ID, *loaded.IdentityID)
}
func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() {
user := s.mustCreateUser("avatar")
inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "inline",
URL: "data:image/png;base64,QUJD",
ContentType: "image/png",
ByteSize: 3,
SHA256: "902fbdd2b1df0c4f70b4a5d23525e932",
})
s.Require().NoError(err)
s.Require().Equal("inline", inlineAvatar.StorageProvider)
s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL)
loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("image/png", loadedAvatar.ContentType)
s.Require().Equal(3, loadedAvatar.ByteSize)
_, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/avatar.png",
})
s.Require().NoError(err)
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().NotNil(loadedAvatar)
s.Require().Equal("remote_url", loadedAvatar.StorageProvider)
s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL)
s.Require().Zero(loadedAvatar.ByteSize)
s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID))
loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Nil(loadedAvatar)
}
func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() {
_, err := integrationDB.ExecContext(s.ctx, `
INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at)
VALUES
('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'),
('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'),
('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`)
s.Require().NoError(err)
summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx)
s.Require().NoError(err)
s.Require().Equal(int64(3), summary.Total)
s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"])
s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"])
reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{
ReportType: "wechat_openid_only_requires_remediation",
Limit: 10,
})
s.Require().NoError(err)
s.Require().Len(reports, 2)
s.Require().Equal("u-2", reports[0].ReportKey)
s.Require().Equal(float64(2), reports[0].Details["user_id"])
report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3")
s.Require().NoError(err)
s.Require().Equal("u-3", report.ReportKey)
s.Require().Equal(float64(3), report.Details["user_id"])
}
func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() {
user := s.mustCreateUser("activity")
loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
activeAt := loginAt.Add(5 * time.Minute)
s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt))
s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt))
var storedLoginAt sqlNullTime
var storedActiveAt sqlNullTime
s.Require().NoError(integrationDB.QueryRowContext(s.ctx, `
SELECT last_login_at, last_active_at
FROM users
WHERE id = $1`,
user.ID,
).Scan(&storedLoginAt, &storedActiveAt))
s.Require().True(storedLoginAt.Valid)
s.Require().True(storedActiveAt.Valid)
s.Require().True(storedLoginAt.Time.Equal(loginAt))
s.Require().True(storedActiveAt.Time.Equal(activeAt))
}
type sqlNullTime struct {
Time time.Time
Valid bool
}
func (t *sqlNullTime) Scan(value any) error {
switch v := value.(type) {
case time.Time:
t.Time = v
t.Valid = true
return nil
case nil:
t.Time = time.Time{}
t.Valid = false
return nil
default:
return fmt.Errorf("unsupported scan type %T", value)
}
}
func stringPtr(v string) *string {
return &v
}
...@@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error ...@@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)).
SetNillableLastLoginAt(userIn.LastLoginAt).
SetNillableLastActiveAt(userIn.LastActiveAt).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists) return translatePersistenceError(err, nil, service.ErrEmailExists)
...@@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error ...@@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged) SetTotalRecharged(userIn.TotalRecharged)
if userIn.SignupSource != "" {
updateOp = updateOp.SetSignupSource(userIn.SignupSource)
}
if userIn.LastLoginAt != nil {
updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt)
}
if userIn.LastActiveAt != nil {
updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt)
}
if userIn.BalanceNotifyThreshold == nil { if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold() updateOp = updateOp.ClearBalanceNotifyThreshold()
} }
...@@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) ...@@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
var field string var field string
defaultField := true defaultField := true
nullsLastField := false
switch sortBy { switch sortBy {
case "email": case "email":
field = dbuser.FieldEmail field = dbuser.FieldEmail
...@@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) ...@@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
case "created_at": case "created_at":
field = dbuser.FieldCreatedAt field = dbuser.FieldCreatedAt
defaultField = false defaultField = false
case "last_login_at":
field = dbuser.FieldLastLoginAt
defaultField = false
nullsLastField = true
case "last_active_at":
field = dbuser.FieldLastActiveAt
defaultField = false
nullsLastField = true
default: default:
field = dbuser.FieldID field = dbuser.FieldID
} }
...@@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) ...@@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector)
if defaultField && field == dbuser.FieldID { if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
} }
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(),
dbent.Asc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
} }
if defaultField && field == dbuser.FieldID { if defaultField && field == dbuser.FieldID {
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
} }
if nullsLastField {
return []func(*entsql.Selector){
entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(),
dbent.Desc(dbuser.FieldID),
}
}
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
} }
...@@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { ...@@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
return return
} }
dst.ID = src.ID dst.ID = src.ID
dst.SignupSource = src.SignupSource
dst.LastLoginAt = src.LastLoginAt
dst.LastActiveAt = src.LastActiveAt
dst.CreatedAt = src.CreatedAt dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt dst.UpdatedAt = src.UpdatedAt
} }
func userSignupSourceOrDefault(signupSource string) string {
signupSource = strings.TrimSpace(signupSource)
if signupSource == "" {
return "email"
}
return signupSource
}
// marshalExtraEmails serializes notify email entries to JSON for storage. // marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string { func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries) return service.MarshalNotifyEmails(entries)
......
...@@ -4,6 +4,7 @@ package repository ...@@ -4,6 +4,7 @@ package repository
import ( import (
"testing" "testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
...@@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { ...@@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
s.Require().Equal(first.ID, users[1].ID) s.Require().Equal(first.ID, users[1].ID)
} }
func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() {
lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond)
created := s.mustCreateUser(&service.User{
Email: "identity-meta@example.com",
SignupSource: "github",
LastLoginAt: &lastLoginAt,
LastActiveAt: &lastActiveAt,
})
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("github", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() {
created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"})
lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond)
lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond)
created.SignupSource = "oidc"
created.LastLoginAt = &lastLoginAt
created.LastActiveAt = &lastActiveAt
s.Require().NoError(s.repo.Update(s.ctx, created))
got, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err)
s.Require().Equal("oidc", got.SignupSource)
s.Require().NotNil(got.LastLoginAt)
s.Require().NotNil(got.LastActiveAt)
s.Require().True(got.LastLoginAt.Equal(lastLoginAt))
s.Require().True(got.LastActiveAt.Equal(lastActiveAt))
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() {
older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond)
newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond)
s.mustCreateUser(&service.User{Email: "nil-login@example.com"})
s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older})
s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_login_at",
SortOrder: "desc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal("newer-login@example.com", users[0].Email)
s.Require().Equal("older-login@example.com", users[1].Email)
s.Require().Equal("nil-login@example.com", users[2].Email)
}
func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() {
earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond)
later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond)
s.mustCreateUser(&service.User{Email: "nil-active@example.com"})
s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later})
s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
Page: 1,
PageSize: 10,
SortBy: "last_active_at",
SortOrder: "asc",
}, service.UserListFilters{})
s.Require().NoError(err)
s.Require().Len(users, 3)
s.Require().Equal("earlier-active@example.com", users[0].Email)
s.Require().Equal("later-active@example.com", users[1].Email)
s.Require().Equal("nil-active@example.com", users[2].Email)
}
func TestUserRepoSortSuiteSmoke(_ *testing.T) {} func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
...@@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectRedirectURL: "",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "false", service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120", service.SettingKeyOIDCConnectClockSkewSeconds: "120",
...@@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) { ...@@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) {
"oidc_connect_redirect_url": "", "oidc_connect_redirect_url": "",
"oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback",
"oidc_connect_token_auth_method": "client_secret_post", "oidc_connect_token_auth_method": "client_secret_post",
"oidc_connect_use_pkce": false, "oidc_connect_use_pkce": true,
"oidc_connect_validate_id_token": true, "oidc_connect_validate_id_token": true,
"oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256",
"oidc_connect_clock_skew_seconds": 120, "oidc_connect_clock_skew_seconds": 120,
......
...@@ -64,12 +64,26 @@ func RegisterAuthRoutes( ...@@ -64,12 +64,26 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword) }), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.POST("/oauth/pending/exchange",
rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.ExchangePendingOAuthCompletion,
)
auth.POST("/oauth/linuxdo/complete-registration", auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose, FailureMode: middleware.RateLimitFailClose,
}), }),
h.Auth.CompleteLinuxDoOAuthRegistration, h.Auth.CompleteLinuxDoOAuthRegistration,
) )
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration", auth.POST("/oauth/oidc/complete-registration",
......
...@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro ...@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
} }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected") panic("unexpected")
} }
......
...@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { ...@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr return s.deleteErr
} }
func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call") panic("unexpected List call")
} }
......
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
)
const (
defaultPendingAuthTTL = 15 * time.Minute
defaultPendingAuthCompletionTTL = 5 * time.Minute
)
type PendingAuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type CreatePendingAuthSessionInput struct {
SessionToken string
Intent string
Identity PendingAuthIdentityKey
TargetUserID *int64
RedirectTo string
ResolvedEmail string
RegistrationPasswordHash string
BrowserSessionKey string
UpstreamIdentityClaims map[string]any
LocalFlowState map[string]any
ExpiresAt time.Time
}
type IssuePendingAuthCompletionCodeInput struct {
PendingAuthSessionID int64
BrowserSessionKey string
TTL time.Duration
}
type IssuePendingAuthCompletionCodeResult struct {
Code string
ExpiresAt time.Time
}
type PendingIdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type AuthPendingIdentityService struct {
entClient *dbent.Client
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken := strings.TrimSpace(input.SessionToken)
if sessionToken == "" {
var err error
sessionToken, err = randomOpaqueToken(24)
if err != nil {
return nil, err
}
}
expiresAt := input.ExpiresAt.UTC()
if expiresAt.IsZero() {
expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
}
create := s.entClient.PendingAuthSession.Create().
SetSessionToken(sessionToken).
SetIntent(strings.TrimSpace(input.Intent)).
SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
SetExpiresAt(expiresAt)
if input.TargetUserID != nil {
create = create.SetTargetUserID(*input.TargetUserID)
}
return create.Save(ctx)
}
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
code, err := randomOpaqueToken(24)
if err != nil {
return nil, err
}
ttl := input.TTL
if ttl <= 0 {
ttl = defaultPendingAuthCompletionTTL
}
expiresAt := time.Now().UTC().Add(ttl)
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeHash(hashPendingAuthCode(code)).
SetCompletionCodeExpiresAt(expiresAt)
if strings.TrimSpace(input.BrowserSessionKey) != "" {
update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
}
if _, err := update.Save(ctx); err != nil {
return nil, err
}
return &IssuePendingAuthCompletionCodeResult{
Code: code,
ExpiresAt: expiresAt,
}, nil
}
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthCodeInvalid
}
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
}
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
}
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken = strings.TrimSpace(sessionToken)
if sessionToken == "" {
return nil, ErrPendingAuthSessionNotFound
}
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) consumeSession(
ctx context.Context,
session *dbent.PendingAuthSession,
browserSessionKey string,
expiredErr error,
consumedErr error,
) (*dbent.PendingAuthSession, error) {
if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
return nil, err
}
return updated, nil
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil {
return ErrPendingAuthSessionNotFound
}
now := time.Now().UTC()
if session.ConsumedAt != nil {
return consumedErr
}
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
return expiredErr
}
if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
return expiredErr
}
if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return ErrPendingAuthBrowserMismatch
}
return nil
}
func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func copyPendingMap(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func randomOpaqueToken(byteLen int) (string, error) {
if byteLen <= 0 {
byteLen = 16
}
buf := make([]byte, byteLen)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func hashPendingAuthCode(code string) string {
sum := sha256.Sum256([]byte(code))
return hex.EncodeToString(sum[:])
}
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return NewAuthPendingIdentityService(client), client
}
func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
targetUser, err := client.User.Create().
SetEmail("pending-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
TargetUserID: &targetUser.ID,
RedirectTo: "/profile",
ResolvedEmail: "user@example.com",
BrowserSessionKey: "browser-1",
UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
LocalFlowState: map[string]any{"step": "email_required"},
})
require.NoError(t, err)
require.NotEmpty(t, session.SessionToken)
require.Equal(t, "bind_current_user", session.Intent)
require.Equal(t, "wechat", session.ProviderType)
require.NotNil(t, session.TargetUserID)
require.Equal(t, targetUser.ID, *session.TargetUserID)
require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
require.Equal(t, "email_required", session.LocalFlowState["step"])
}
func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expected",
UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
LocalFlowState: map[string]any{"step": "pending"},
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expected",
})
require.NoError(t, err)
require.NotEmpty(t, issued.Code)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Empty(t, consumed.CompletionCodeHash)
require.Nil(t, consumed.CompletionCodeExpiresAt)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
}
func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expired",
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expired",
TTL: time.Second,
})
require.NoError(t, err)
_, err = client.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-open").
SetProviderSubject("union-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
require.NoError(t, err)
first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
require.NoError(t, err)
require.True(t, first.AdoptDisplayName)
require.False(t, first.AdoptAvatar)
require.Nil(t, first.IdentityID)
second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
require.NoError(t, err)
require.Equal(t, first.ID, second.ID)
require.NotNil(t, second.IdentityID)
require.Equal(t, identity.ID, *second.IdentityID)
require.True(t, second.AdoptAvatar)
}
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "subject-session-token",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "token",
},
},
})
require.NoError(t, err)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
...@@ -106,6 +107,13 @@ func NewAuthService( ...@@ -106,6 +107,13 @@ func NewAuthService(
} }
} }
func (s *AuthService) EntClient() *dbent.Client {
if s == nil {
return nil
}
return s.entClient
}
// Register 用户注册,返回token和用户 // Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "") return s.RegisterWithVerification(ctx, email, password, "", "", "")
...@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw ...@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable return "", nil, ErrServiceUnavailable
} }
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码) // 标记邀请码为已使用(如果使用了邀请码)
...@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string ...@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() { if !user.IsActive() {
return "", nil, ErrUserNotActive return "", nil, ErrUserNotActive
} }
s.touchUserLogin(ctx, user.ID)
// 生成JWT token // 生成JWT token
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
...@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignDefaultSubscriptions(ctx, user.ID)
} }
} else { } else {
...@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username ...@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
} }
} }
s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user) token, err := s.GenerateToken(user)
if err != nil { if err != nil {
...@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable return nil, nil, ErrServiceUnavailable
} }
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignDefaultSubscriptions(ctx, user.ID)
} }
} else { } else {
...@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
} }
} else { } else {
user = newUser user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID) s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil { if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
...@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
} }
} }
s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "") tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil { if err != nil {
...@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema ...@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil return tokenPair, user, nil
} }
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return return
...@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int ...@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
} }
} }
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
signupSource = "email"
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if signupSource == "email" {
s.ensureEmailAuthIdentity(ctx, user)
}
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
return
}
if err := s.entClient.User.UpdateOneID(userID).
SetSignupSource(signupSource).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
}
}
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
now := time.Now().UTC()
if err := s.entClient.User.UpdateOneID(userID).
SetLastLoginAt(now).
SetLastActiveAt(now).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
}
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return
}
if err := s.entClient.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
"source": "auth_service_dual_write",
}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
}
}
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
return "oidc"
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
return "wechat"
default:
return "email"
}
}
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil { if s.settingService == nil {
return nil return nil
...@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) { ...@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool { func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email)) normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
} }
// GenerateToken 生成JWT access token // GenerateToken 生成JWT access token
......
//go:build unit
package service_test
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type authIdentitySettingRepoStub struct {
values map[string]string
}
func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-auth-identity-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
},
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "email", storedUser.SignupSource)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("user@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.NotNil(t, identity.VerifiedAt)
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t)
ctx := context.Background()
user := &service.User{
Email: "login@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 1,
Concurrency: 1,
}
require.NoError(t, user.SetPassword("password"))
require.NoError(t, repo.Create(ctx, user))
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err := client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old))
require.True(t, storedUser.LastActiveAt.After(old))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment