Unverified Commit 8eb3f9e7 authored by Wesley Liddick's avatar Wesley Liddick Committed by GitHub
Browse files

Merge pull request #1785 from IanShaw027/rebuild/auth-identity-foundation

feat(auth,payment): 重构认证身份和支付系统及其他部分优化
parents 78f691d2 7fbd5177
......@@ -64,12 +64,70 @@ func RegisterAuthRoutes(
}), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart)
auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback)
auth.GET("/oauth/wechat/payment/start", h.Auth.WeChatPaymentOAuthStart)
auth.GET("/oauth/wechat/payment/callback", h.Auth.WeChatPaymentOAuthCallback)
auth.POST("/oauth/pending/exchange",
rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.ExchangePendingOAuthCompletion,
)
auth.POST("/oauth/pending/send-verify-code",
rateLimiter.LimitWithOptions("oauth-pending-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.SendPendingOAuthVerifyCode,
)
auth.POST("/oauth/pending/create-account",
rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreatePendingOAuthAccount,
)
auth.POST("/oauth/pending/bind-login",
rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindPendingOAuthLogin,
)
auth.POST("/oauth/linuxdo/complete-registration",
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteLinuxDoOAuthRegistration,
)
auth.POST("/oauth/linuxdo/bind-login",
rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindLinuxDoOAuthLogin,
)
auth.POST("/oauth/linuxdo/create-account",
rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateLinuxDoOAuthAccount,
)
auth.POST("/oauth/wechat/complete-registration",
rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CompleteWeChatOAuthRegistration,
)
auth.POST("/oauth/wechat/bind-login",
rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindWeChatOAuthLogin,
)
auth.POST("/oauth/wechat/create-account",
rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateWeChatOAuthAccount,
)
auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart)
auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback)
auth.POST("/oauth/oidc/complete-registration",
......@@ -78,6 +136,18 @@ func RegisterAuthRoutes(
}),
h.Auth.CompleteOIDCOAuthRegistration,
)
auth.POST("/oauth/oidc/bind-login",
rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.BindOIDCOAuthLogin,
)
auth.POST("/oauth/oidc/create-account",
rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}),
h.Auth.CreateOIDCOAuthAccount,
)
}
// 公开设置(无需认证)
......@@ -94,5 +164,23 @@ func RegisterAuthRoutes(
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
authenticated.GET("/auth/oauth/linuxdo/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.LinuxDoOAuthStart(c)
})
authenticated.GET("/auth/oauth/oidc/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.OIDCOAuthStart(c)
})
authenticated.GET("/auth/oauth/wechat/bind/start", func(c *gin.Context) {
query := c.Request.URL.Query()
query.Set("intent", "bind_current_user")
c.Request.URL.RawQuery = query.Encode()
h.Auth.WeChatOAuthStart(c)
})
}
}
......@@ -52,6 +52,7 @@ func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
"/api/v1/auth/oauth/pending/send-verify-code",
}
for _, path := range paths {
......
......@@ -44,11 +44,13 @@ func RegisterPaymentRoutes(
}
// --- Public payment endpoints (no auth) ---
// Payment result page needs to verify order status without login
// (user session may have expired during provider redirect).
// Signed resume-token recovery is the supported public lookup path.
// The legacy anonymous out_trade_no verify endpoint is kept only as a
// compatibility shim that returns HTTP 410 Gone.
public := v1.Group("/payment/public")
{
public.POST("/orders/verify", paymentHandler.VerifyOrderPublic)
public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken)
}
// --- Webhook endpoints (no auth) ---
......
......@@ -25,6 +25,10 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
user.POST("/account-bindings/email/send-code", h.User.SendEmailBindingCode)
user.POST("/account-bindings/email", h.User.BindEmailIdentity)
user.DELETE("/account-bindings/:provider", h.User.UnbindIdentity)
user.POST("/auth-identities/bind/start", h.User.StartIdentityBinding)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")
......
......@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
......@@ -11,6 +12,8 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
......@@ -33,6 +36,7 @@ type AdminService interface {
// codeType is optional - pass empty string to return all types.
// Also returns totalRecharged (sum of all positive balance top-ups).
GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error)
BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error)
......@@ -127,6 +131,44 @@ type UpdateUserInput struct {
GroupRates map[int64]*float64
}
type AdminBindAuthIdentityInput struct {
ProviderType string
ProviderKey string
ProviderSubject string
Issuer *string
Metadata map[string]any
Channel *AdminBindAuthIdentityChannelInput
}
type AdminBindAuthIdentityChannelInput struct {
Channel string
ChannelAppID string
ChannelSubject string
Metadata map[string]any
}
type AdminBoundAuthIdentity struct {
UserID int64 `json:"user_id"`
ProviderType string `json:"provider_type"`
ProviderKey string `json:"provider_key"`
ProviderSubject string `json:"provider_subject"`
VerifiedAt *time.Time `json:"verified_at,omitempty"`
Issuer *string `json:"issuer,omitempty"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"`
}
type AdminBoundAuthIdentityChannel struct {
Channel string `json:"channel"`
ChannelAppID string `json:"channel_app_id"`
ChannelSubject string `json:"channel_subject"`
Metadata map[string]any `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type CreateGroupInput struct {
Name string
Description string
......@@ -491,6 +533,20 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
if len(users) > 0 {
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs)
if latestErr != nil {
logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr)
} else {
for i := range users {
users[i].LastUsedAt = lastUsedByUserID[users[i].ID]
}
}
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok {
......@@ -535,6 +591,12 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error)
if err != nil {
return nil, err
}
lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id)
if latestErr != nil {
logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr)
} else {
user.LastUsedAt = lastUsedAt
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
......@@ -797,6 +859,227 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int
return codes, result.Total, totalRecharged, nil
}
func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) {
if userID <= 0 {
return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0")
}
if s == nil || s.entClient == nil || s.userRepo == nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable")
}
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
return nil, err
}
providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType)
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerType == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
}
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
}
var issuer *string
if input.Issuer != nil {
trimmed := strings.TrimSpace(*input.Issuer)
if trimmed != "" {
issuer = &trimmed
}
}
channelInput := normalizeAdminBindChannelInput(input.Channel)
if input.Channel != nil && channelInput == nil {
return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided")
}
verifiedAt := time.Now().UTC()
tx, err := s.entClient.Tx(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err)
}
defer func() { _ = tx.Rollback() }()
identity, err := tx.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyEQ(providerKey),
authidentity.ProviderSubjectEQ(providerSubject),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
if identity != nil && identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
if identity == nil {
create := tx.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetVerifiedAt(verifiedAt)
if issuer != nil {
create = create.SetIssuer(*issuer)
}
if input.Metadata != nil {
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
}
identity, err = create.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
} else {
update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt)
if issuer != nil {
update = update.SetIssuer(*issuer)
}
if input.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata))
}
identity, err = update.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err)
}
}
var channel *dbent.AuthIdentityChannel
if channelInput != nil {
channel, err = tx.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyEQ(providerKey),
authidentitychannel.ChannelEQ(channelInput.Channel),
authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID),
authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject),
).
WithIdentity().
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err)
}
if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID {
return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
if channel == nil {
create := tx.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetChannel(channelInput.Channel).
SetChannelAppID(channelInput.ChannelAppID).
SetChannelSubject(channelInput.ChannelSubject)
if channelInput.Metadata != nil {
create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
channel, err = create.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
} else {
update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID)
if channelInput.Metadata != nil {
update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata))
}
channel, err = update.Save(ctx)
if err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err)
}
}
}
if err := tx.Commit(); err != nil {
return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err)
}
return buildAdminBoundAuthIdentity(identity, channel), nil
}
func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput {
if input == nil {
return nil
}
channel := &AdminBindAuthIdentityChannelInput{
Channel: strings.TrimSpace(input.Channel),
ChannelAppID: strings.TrimSpace(input.ChannelAppID),
ChannelSubject: strings.TrimSpace(input.ChannelSubject),
Metadata: cloneAdminAuthIdentityMetadata(input.Metadata),
}
if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" {
return nil
}
return channel
}
func normalizeAdminAuthIdentityProviderType(input string) string {
switch strings.ToLower(strings.TrimSpace(input)) {
case "email":
return "email"
case "linuxdo":
return "linuxdo"
case "oidc":
return "oidc"
case "wechat":
return "wechat"
default:
return ""
}
}
func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity {
if identity == nil {
return nil
}
result := &AdminBoundAuthIdentity{
UserID: identity.UserID,
ProviderType: strings.TrimSpace(identity.ProviderType),
ProviderKey: strings.TrimSpace(identity.ProviderKey),
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
VerifiedAt: identity.VerifiedAt,
Issuer: identity.Issuer,
Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata),
CreatedAt: identity.CreatedAt,
UpdatedAt: identity.UpdatedAt,
}
if channel != nil {
result.Channel = &AdminBoundAuthIdentityChannel{
Channel: strings.TrimSpace(channel.Channel),
ChannelAppID: strings.TrimSpace(channel.ChannelAppID),
ChannelSubject: strings.TrimSpace(channel.ChannelSubject),
Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata),
CreatedAt: channel.CreatedAt,
UpdatedAt: channel.UpdatedAt,
}
}
return result
}
func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any {
if input == nil {
return nil
}
if len(input) == 0 {
return map[string]any{}
}
data, err := json.Marshal(input)
if err != nil {
out := make(map[string]any, len(input))
for key, value := range input {
out[key] = value
}
return out
}
var out map[string]any
if err := json.Unmarshal(data, &out); err != nil {
out = make(map[string]any, len(input))
for key, value := range input {
out[key] = value
}
}
return out
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder}
......
......@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
......@@ -70,6 +79,23 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UnbindUserAuthProvider(context.Context, int64, string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
......
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/enttest"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("bind-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "union-123",
Metadata: map[string]any{"source": "admin-repair"},
Channel: &AdminBindAuthIdentityChannelInput{
Channel: "open",
ChannelAppID: "wx-open",
ChannelSubject: "openid-123",
Metadata: map[string]any{"scene": "migration"},
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, user.ID, result.UserID)
require.Equal(t, "wechat", result.ProviderType)
require.Equal(t, "wechat-main", result.ProviderKey)
require.NotNil(t, result.VerifiedAt)
require.NotNil(t, result.Channel)
require.Equal(t, "open", result.Channel.Channel)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("wechat"),
authidentity.ProviderKeyEQ("wechat-main"),
authidentity.ProviderSubjectEQ("union-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.Equal(t, "admin-repair", identity.Metadata["source"])
require.NotNil(t, identity.VerifiedAt)
channel, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ("wechat"),
authidentitychannel.ProviderKeyEQ("wechat-main"),
authidentitychannel.ChannelEQ("open"),
authidentitychannel.ChannelAppIDEQ("wx-open"),
authidentitychannel.ChannelSubjectEQ("openid-123"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, identity.ID, channel.IdentityID)
require.Equal(t, "migration", channel.Metadata["scene"])
}
func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
owner, err := client.User.Create().
SetEmail("owner@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
target, err := client.User.Create().
SetEmail("target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(owner.ID).
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("subject-1").
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}},
entClient: client,
}
_, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
})
require.Error(t, err)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err))
}
func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("same-user@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-2",
Metadata: map[string]any{"source": "first"},
})
require.NoError(t, err)
second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-2",
Metadata: map[string]any{"source": "second"},
})
require.NoError(t, err)
require.Equal(t, first.UserID, second.UserID)
require.Equal(t, "second", second.Metadata["source"])
identities, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example"),
authidentity.ProviderSubjectEQ("subject-2"),
).
All(ctx)
require.NoError(t, err)
require.Len(t, identities, 1)
require.Equal(t, "second", identities[0].Metadata["source"])
}
func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) {
client := newAdminServiceAuthIdentityBindingTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("invalid-provider@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
svc := &adminServiceImpl{
userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}},
entClient: client,
}
_, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{
ProviderType: "github",
ProviderKey: "github-main",
ProviderSubject: "subject-3",
})
require.Error(t, err)
require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err))
}
......@@ -13,15 +13,18 @@ import (
)
type userRepoStub struct {
user *User
getErr error
createErr error
deleteErr error
exists bool
existsErr error
nextID int64
created []*User
deletedIDs []int64
user *User
getErr error
createErr error
deleteErr error
exists bool
existsErr error
nextID int64
created []*User
updated []*User
deletedIDs []int64
usersByEmail map[string]*User
getByEmailErr error
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
......@@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error {
user.ID = s.nextID
}
s.created = append(s.created, user)
if s.usersByEmail == nil {
s.usersByEmail = make(map[string]*User)
}
s.usersByEmail[user.Email] = user
s.user = user
return nil
}
......@@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
panic("unexpected GetByEmail call")
if s.getByEmailErr != nil {
return nil, s.getByEmailErr
}
if s.usersByEmail != nil {
if user, ok := s.usersByEmail[email]; ok {
return user, nil
}
}
if s.user != nil && s.user.Email == email {
return s.user, nil
}
return nil, ErrUserNotFound
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
......@@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
panic("unexpected Update call")
s.updated = append(s.updated, user)
if s.usersByEmail == nil {
s.usersByEmail = make(map[string]*User)
}
s.usersByEmail[user.Email] = user
s.user = user
return nil
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
......@@ -62,6 +87,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
......@@ -70,6 +107,18 @@ func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.Pa
panic("unexpected ListWithFilters call")
}
func (s *userRepoStub) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserIDs call")
}
func (s *userRepoStub) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) {
panic("unexpected GetLatestUsedAtByUserID call")
}
func (s *userRepoStub) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error {
panic("unexpected UpdateUserLastActiveAt call")
}
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
......@@ -101,6 +150,14 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) {
panic("unexpected ListUserAuthIdentities call")
}
func (s *userRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error {
panic("unexpected UnbindUserAuthProvider call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
......
//go:build unit
package service
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type ensureEmailCall struct {
userID int64
email string
}
type replaceEmailCall struct {
userID int64
oldEmail string
newEmail string
}
type emailSyncRepoStub struct {
user *User
nextID int64
updateCalls int
created []*User
updated []*User
ensureCalls []ensureEmailCall
replaceCalls []replaceEmailCall
ensureErr error
replaceErr error
}
func (s *emailSyncRepoStub) Create(_ context.Context, user *User) error {
if s.nextID != 0 && user.ID == 0 {
user.ID = s.nextID
}
s.created = append(s.created, user)
s.user = user
return nil
}
func (s *emailSyncRepoStub) GetByID(_ context.Context, _ int64) (*User, error) {
if s.user == nil {
return nil, ErrUserNotFound
}
cloned := *s.user
return &cloned, nil
}
func (s *emailSyncRepoStub) GetByEmail(_ context.Context, _ string) (*User, error) {
return nil, ErrUserNotFound
}
func (s *emailSyncRepoStub) GetFirstAdmin(context.Context) (*User, error) {
return nil, fmt.Errorf("unexpected GetFirstAdmin call")
}
func (s *emailSyncRepoStub) Update(_ context.Context, user *User) error {
s.updateCalls++
s.updated = append(s.updated, user)
s.user = user
return nil
}
func (s *emailSyncRepoStub) Delete(context.Context, int64) error { return nil }
func (s *emailSyncRepoStub) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
return nil, fmt.Errorf("unexpected GetUserAvatar call")
}
func (s *emailSyncRepoStub) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
return nil, fmt.Errorf("unexpected UpsertUserAvatar call")
}
func (s *emailSyncRepoStub) DeleteUserAvatar(context.Context, int64) error {
return fmt.Errorf("unexpected DeleteUserAvatar call")
}
func (s *emailSyncRepoStub) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, fmt.Errorf("unexpected List call")
}
func (s *emailSyncRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
return nil, nil, fmt.Errorf("unexpected ListWithFilters call")
}
func (s *emailSyncRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) {
return map[int64]*time.Time{}, nil
}
func (s *emailSyncRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) {
return nil, nil
}
func (s *emailSyncRepoStub) UpdateUserLastActiveAt(context.Context, int64, time.Time) error {
return nil
}
func (s *emailSyncRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil }
func (s *emailSyncRepoStub) DeductBalance(context.Context, int64, float64) error { return nil }
func (s *emailSyncRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (s *emailSyncRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
return nil
}
func (s *emailSyncRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) {
return nil, nil
}
func (s *emailSyncRepoStub) UnbindUserAuthProvider(context.Context, int64, string) error { return nil }
func (s *emailSyncRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (s *emailSyncRepoStub) EnableTotp(context.Context, int64) error { return nil }
func (s *emailSyncRepoStub) DisableTotp(context.Context, int64) error { return nil }
func (s *emailSyncRepoStub) EnsureEmailAuthIdentity(_ context.Context, userID int64, email string) error {
s.ensureCalls = append(s.ensureCalls, ensureEmailCall{userID: userID, email: email})
return s.ensureErr
}
func (s *emailSyncRepoStub) ReplaceEmailAuthIdentity(_ context.Context, userID int64, oldEmail, newEmail string) error {
s.replaceCalls = append(s.replaceCalls, replaceEmailCall{
userID: userID,
oldEmail: oldEmail,
newEmail: newEmail,
})
return s.replaceErr
}
func TestAdminService_CreateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
nextID: 55,
ensureErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
user, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "admin-created@example.com",
Password: "strong-pass",
})
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, int64(55), user.ID)
require.Empty(t, repo.ensureCalls)
require.Empty(t, repo.replaceCalls)
}
func TestAdminService_UpdateUser_DoesNotReturnPartialSuccessFromEmailIdentityResync(t *testing.T) {
repo := &emailSyncRepoStub{
user: &User{
ID: 91,
Email: "before@example.com",
Role: RoleUser,
Status: StatusActive,
Concurrency: 3,
},
replaceErr: fmt.Errorf("unexpected email resync"),
}
svc := &adminServiceImpl{userRepo: repo}
updated, err := svc.UpdateUser(context.Background(), 91, &UpdateUserInput{
Email: "after@example.com",
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, "after@example.com", updated.Email)
require.Empty(t, repo.replaceCalls)
require.Empty(t, repo.ensureCalls)
}
......@@ -6,6 +6,7 @@ import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
......@@ -16,6 +17,8 @@ type userRepoStubForListUsers struct {
users []User
err error
listWithFiltersParams pagination.PaginationParams
lastUsedByUserID map[int64]*time.Time
lastUsedErr error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
......@@ -32,6 +35,26 @@ func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pag
}, nil
}
func (s *userRepoStubForListUsers) GetLatestUsedAtByUserIDs(_ context.Context, userIDs []int64) (map[int64]*time.Time, error) {
if s.lastUsedErr != nil {
return nil, s.lastUsedErr
}
result := make(map[int64]*time.Time, len(userIDs))
for _, userID := range userIDs {
if ts, ok := s.lastUsedByUserID[userID]; ok {
result[userID] = ts
}
}
return result, nil
}
func (s *userRepoStubForListUsers) GetLatestUsedAtByUserID(_ context.Context, userID int64) (*time.Time, error) {
if s.lastUsedErr != nil {
return nil, s.lastUsedErr
}
return s.lastUsedByUserID[userID], nil
}
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
......@@ -130,3 +153,21 @@ func TestAdminService_ListUsers_PassesSortParams(t *testing.T) {
SortOrder: "ASC",
}, userRepo.listWithFiltersParams)
}
func TestAdminService_ListUsers_PopulatesLastUsedAt(t *testing.T) {
lastUsed := time.Now().UTC().Add(-30 * time.Minute).Truncate(time.Second)
userRepo := &userRepoStubForListUsers{
users: []User{{ID: 101, Email: "u@example.com"}},
lastUsedByUserID: map[int64]*time.Time{
101: &lastUsed,
},
}
svc := &adminServiceImpl{userRepo: userRepo}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "")
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Len(t, users, 1)
require.NotNil(t, users[0].LastUsedAt)
require.WithinDuration(t, lastUsed, *users[0].LastUsedAt, time.Second)
}
......@@ -5,6 +5,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
......@@ -34,8 +35,23 @@ const (
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
ErrAnnouncementNilInput = infraerrors.BadRequest("ANNOUNCEMENT_INPUT_REQUIRED", "announcement input is required")
ErrAnnouncementInvalidTitle = infraerrors.BadRequest("ANNOUNCEMENT_TITLE_INVALID", "announcement title is invalid")
ErrAnnouncementContentRequired = infraerrors.BadRequest(
"ANNOUNCEMENT_CONTENT_REQUIRED",
"announcement content is required",
)
ErrAnnouncementInvalidStatus = infraerrors.BadRequest("ANNOUNCEMENT_STATUS_INVALID", "announcement status is invalid")
ErrAnnouncementInvalidNotifyMode = infraerrors.BadRequest(
"ANNOUNCEMENT_NOTIFY_MODE_INVALID",
"announcement notify_mode is invalid",
)
ErrAnnouncementInvalidSchedule = infraerrors.BadRequest(
"ANNOUNCEMENT_TIME_RANGE_INVALID",
"starts_at must be before ends_at",
)
)
type AnnouncementTargeting = domain.AnnouncementTargeting
......
......@@ -70,16 +70,16 @@ type AnnouncementUserReadStatus struct {
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
return nil, ErrAnnouncementNilInput
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
return nil, ErrAnnouncementInvalidTitle
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
return nil, ErrAnnouncementContentRequired
}
status := strings.TrimSpace(input.Status)
......@@ -87,7 +87,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
return nil, ErrAnnouncementInvalidStatus
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
......@@ -100,12 +100,12 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("create announcement: invalid notify_mode")
return nil, ErrAnnouncementInvalidNotifyMode
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
return nil, ErrAnnouncementInvalidSchedule
}
}
......@@ -131,7 +131,7 @@ func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncem
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
return nil, ErrAnnouncementNilInput
}
a, err := s.announcementRepo.GetByID(ctx, id)
......@@ -142,21 +142,21 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
return nil, ErrAnnouncementInvalidTitle
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
return nil, ErrAnnouncementContentRequired
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
return nil, ErrAnnouncementInvalidStatus
}
a.Status = status
}
......@@ -164,7 +164,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("update announcement: invalid notify_mode")
return nil, ErrAnnouncementInvalidNotifyMode
}
a.NotifyMode = notifyMode
}
......@@ -186,7 +186,7 @@ func (s *AnnouncementService) Update(ctx context.Context, id int64, input *Updat
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
return nil, ErrAnnouncementInvalidSchedule
}
}
......
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type announcementRepoStub struct {
item *Announcement
}
func (s *announcementRepoStub) Create(_ context.Context, a *Announcement) error {
s.item = a
return nil
}
func (s *announcementRepoStub) GetByID(_ context.Context, _ int64) (*Announcement, error) {
if s.item == nil {
return nil, ErrAnnouncementNotFound
}
return s.item, nil
}
func (s *announcementRepoStub) Update(_ context.Context, a *Announcement) error {
s.item = a
return nil
}
func (*announcementRepoStub) Delete(context.Context, int64) error {
return nil
}
func (*announcementRepoStub) List(context.Context, pagination.PaginationParams, AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (*announcementRepoStub) ListActive(context.Context, time.Time) ([]Announcement, error) {
return nil, nil
}
func TestAnnouncementServiceCreateRejectsEqualStartEndTimes(t *testing.T) {
repo := &announcementRepoStub{}
svc := NewAnnouncementService(repo, nil, nil, nil)
now := time.Unix(1776790020, 0)
_, err := svc.Create(context.Background(), &CreateAnnouncementInput{
Title: "公告",
Content: "内容",
Status: AnnouncementStatusActive,
NotifyMode: AnnouncementNotifyModePopup,
StartsAt: &now,
EndsAt: &now,
})
require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
}
func TestAnnouncementServiceUpdateRejectsEqualStartEndTimes(t *testing.T) {
repo := &announcementRepoStub{
item: &Announcement{
ID: 1,
Title: "公告",
Content: "内容",
Status: AnnouncementStatusActive,
NotifyMode: AnnouncementNotifyModePopup,
},
}
svc := NewAnnouncementService(repo, nil, nil, nil)
now := time.Unix(1776790020, 0)
startsAt := &now
endsAt := &now
_, err := svc.Update(context.Background(), 1, &UpdateAnnouncementInput{
StartsAt: &startsAt,
EndsAt: &endsAt,
})
require.ErrorIs(t, err, ErrAnnouncementInvalidSchedule)
}
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// BindEmailIdentity verifies and binds a local email/password identity to the
// current user, or replaces the existing bound primary email.
func (s *AuthService) BindEmailIdentity(
ctx context.Context,
userID int64,
email string,
verifyCode string,
password string,
) (*User, error) {
if s == nil {
return nil, ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return nil, err
}
if isReservedEmail(normalizedEmail) {
return nil, ErrEmailReserved
}
if strings.TrimSpace(password) == "" {
return nil, ErrPasswordRequired
}
if err := s.VerifyOAuthEmailCode(ctx, normalizedEmail, verifyCode); err != nil {
return nil, err
}
currentUser, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
firstRealEmailBind := !hasBindableEmailIdentitySubject(currentUser.Email)
if firstRealEmailBind && len(password) < 6 {
return nil, infraerrors.BadRequest("PASSWORD_TOO_SHORT", "password must be at least 6 characters")
}
if !firstRealEmailBind && !s.CheckPassword(password, currentUser.PasswordHash) {
return nil, ErrPasswordIncorrect
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return nil, ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
if s.entClient != nil {
if err := s.updateBoundEmailIdentityTx(ctx, currentUser, normalizedEmail, hashedPassword, firstRealEmailBind); err != nil {
return nil, err
}
return currentUser, nil
}
currentUser.Email = normalizedEmail
currentUser.PasswordHash = hashedPassword
if err := s.userRepo.Update(ctx, currentUser); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, ErrEmailExists
}
return nil, ErrServiceUnavailable
}
if firstRealEmailBind {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, userID, "email"); err != nil {
return nil, fmt.Errorf("apply email first bind defaults: %w", err)
}
}
return currentUser, nil
}
// SendEmailIdentityBindCode sends a verification code for authenticated email binding flows.
func (s *AuthService) SendEmailIdentityBindCode(ctx context.Context, userID int64, email string) error {
if s == nil {
return ErrServiceUnavailable
}
normalizedEmail, err := normalizeEmailForIdentityBinding(email)
if err != nil {
return err
}
if isReservedEmail(normalizedEmail) {
return ErrEmailReserved
}
if s.emailService == nil {
return ErrServiceUnavailable
}
if _, err := s.userRepo.GetByID(ctx, userID); err != nil {
if errors.Is(err, ErrUserNotFound) {
return ErrUserNotFound
}
return ErrServiceUnavailable
}
existingUser, err := s.userRepo.GetByEmail(ctx, normalizedEmail)
switch {
case err == nil && existingUser != nil && existingUser.ID != userID:
return ErrEmailExists
case err != nil && !errors.Is(err, ErrUserNotFound):
return ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, normalizedEmail, siteName)
}
func normalizeEmailForIdentityBinding(email string) (string, error) {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" || len(normalized) > 255 {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(normalized); err != nil {
return "", infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
return normalized, nil
}
func hasBindableEmailIdentitySubject(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return normalized != "" && !isReservedEmail(normalized)
}
func (s *AuthService) updateBoundEmailIdentityTx(
ctx context.Context,
currentUser *User,
email string,
hashedPassword string,
applyFirstBindDefaults bool,
) error {
if tx := dbent.TxFromContext(ctx); tx != nil {
return s.updateBoundEmailIdentityWithClient(ctx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults)
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
return ErrServiceUnavailable
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.updateBoundEmailIdentityWithClient(txCtx, tx.Client(), currentUser, email, hashedPassword, applyFirstBindDefaults); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return ErrServiceUnavailable
}
return nil
}
func (s *AuthService) updateBoundEmailIdentityWithClient(
ctx context.Context,
client *dbent.Client,
currentUser *User,
email string,
hashedPassword string,
applyFirstBindDefaults bool,
) error {
if client == nil || currentUser == nil || currentUser.ID <= 0 {
return ErrServiceUnavailable
}
oldEmail := currentUser.Email
if _, err := client.User.UpdateOneID(currentUser.ID).
SetEmail(email).
SetPasswordHash(hashedPassword).
Save(ctx); err != nil {
if dbent.IsConstraintError(err) {
return ErrEmailExists
}
return ErrServiceUnavailable
}
if err := replaceBoundEmailAuthIdentityWithClient(ctx, client, currentUser.ID, oldEmail, email, "auth_service_email_bind"); err != nil {
if errors.Is(err, ErrEmailExists) {
return ErrEmailExists
}
return ErrServiceUnavailable
}
if applyFirstBindDefaults {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, currentUser.ID, "email"); err != nil {
return fmt.Errorf("apply email first bind defaults: %w", err)
}
}
updatedUser, err := client.User.Get(ctx, currentUser.ID)
if err != nil {
return ErrServiceUnavailable
}
currentUser.Email = updatedUser.Email
currentUser.PasswordHash = updatedUser.PasswordHash
currentUser.Balance = updatedUser.Balance
currentUser.Concurrency = updatedUser.Concurrency
currentUser.UpdatedAt = updatedUser.UpdatedAt
return nil
}
func replaceBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,
userID int64,
oldEmail string,
newEmail string,
source string,
) error {
newSubject := normalizeBoundEmailAuthIdentitySubject(newEmail)
if err := ensureBoundEmailAuthIdentityWithClient(ctx, client, userID, newSubject, source); err != nil {
return err
}
oldSubject := normalizeBoundEmailAuthIdentitySubject(oldEmail)
if oldSubject == "" || oldSubject == newSubject {
return nil
}
_, err := client.AuthIdentity.Delete().
Where(
authidentity.UserIDEQ(userID),
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(oldSubject),
).
Exec(ctx)
return err
}
func ensureBoundEmailAuthIdentityWithClient(
ctx context.Context,
client *dbent.Client,
userID int64,
subject string,
source string,
) error {
if client == nil || userID <= 0 || subject == "" {
return nil
}
if strings.TrimSpace(source) == "" {
source = "auth_service_email_bind"
}
if err := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(subject).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{"source": strings.TrimSpace(source)}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
if !isSQLNoRowsError(err) {
return err
}
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(subject),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity.UserID != userID {
return ErrEmailExists
}
return nil
}
func normalizeBoundEmailAuthIdentitySubject(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" || isReservedEmail(normalized) {
return ""
}
return normalized
}
package service
import (
"context"
"errors"
"fmt"
"net/mail"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
)
func normalizeOAuthSignupSource(signupSource string) string {
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
return "email"
}
return signupSource
}
// SendPendingOAuthVerifyCode sends a local verification code for pending OAuth
// account-creation flows without relying on the public registration gate.
func (s *AuthService) SendPendingOAuthVerifyCode(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
email = strings.TrimSpace(strings.ToLower(email))
if email == "" {
return nil, ErrEmailVerifyRequired
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, ErrEmailVerifyRequired
}
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
if s == nil || s.emailService == nil {
return nil, ErrServiceUnavailable
}
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
if err := s.emailService.SendVerifyCode(ctx, email, siteName); err != nil {
return nil, err
}
return &SendVerifyCodeResult{
Countdown: int(verifyCodeCooldown / time.Second),
}, nil
}
func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil, nil
}
if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return nil, ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" {
return nil, ErrInvitationCodeRequired
}
redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return nil, ErrInvitationCodeInvalid
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
return nil, ErrInvitationCodeInvalid
}
return redeemCode, nil
}
// VerifyOAuthEmailCode verifies the locally entered email verification code for
// third-party signup and binding flows. This is intentionally independent from
// the global registration email verification toggle.
func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error {
email = strings.TrimSpace(strings.ToLower(email))
verifyCode = strings.TrimSpace(verifyCode)
if email == "" {
return ErrEmailVerifyRequired
}
if verifyCode == "" {
return ErrEmailVerifyRequired
}
if s == nil || s.emailService == nil {
return ErrServiceUnavailable
}
return s.emailService.VerifyCode(ctx, email, verifyCode)
}
// RegisterOAuthEmailAccount creates a local account from a third-party first
// login after the user has verified a local email address.
func (s *AuthService) RegisterOAuthEmailAccount(
ctx context.Context,
email string,
password string,
verifyCode string,
invitationCode string,
signupSource string,
) (*TokenPair, *User, error) {
if s == nil {
return nil, nil, ErrServiceUnavailable
}
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
email = strings.TrimSpace(strings.ToLower(email))
if isReservedEmail(email) {
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
return nil, nil, err
}
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
return nil, nil, err
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
return nil, nil, ErrEmailExists
}
hashedPassword, err := s.HashPassword(password)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
signupSource = strings.TrimSpace(strings.ToLower(signupSource))
if signupSource == "" {
signupSource = "email"
}
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
return nil, nil, ErrServiceUnavailable
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
_ = s.RollbackOAuthEmailAccountCreation(ctx, user.ID, "")
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// FinalizeOAuthEmailAccount applies invitation usage and normal signup bootstrap
// only after the pending OAuth flow has fully reached its last reversible step.
func (s *AuthService) FinalizeOAuthEmailAccount(
ctx context.Context,
user *User,
invitationCode string,
signupSource string,
) error {
if s == nil || user == nil || user.ID <= 0 {
return ErrServiceUnavailable
}
signupSource = normalizeOAuthSignupSource(signupSource)
invitationRedeemCode, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
return err
}
if invitationRedeemCode != nil {
if err := s.useOAuthRegistrationInvitation(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return ErrInvitationCodeInvalid
}
}
s.updateOAuthSignupSource(ctx, user.ID, signupSource)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
return nil
}
// RollbackOAuthEmailAccountCreation removes a partially-created local account
// and restores any invitation code already consumed by that account.
func (s *AuthService) RollbackOAuthEmailAccountCreation(ctx context.Context, userID int64, invitationCode string) error {
if s == nil || s.userRepo == nil || userID <= 0 {
return ErrServiceUnavailable
}
if err := s.restoreOAuthRegistrationInvitation(ctx, invitationCode, userID); err != nil {
return err
}
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete created oauth user: %w", err)
}
return nil
}
func (s *AuthService) restoreOAuthRegistrationInvitation(ctx context.Context, invitationCode string, userID int64) error {
if s == nil || s.settingService == nil || !s.settingService.IsInvitationCodeEnabled(ctx) {
return nil
}
if s.redeemRepo == nil && s.oauthEmailFlowClient(ctx) == nil {
return ErrServiceUnavailable
}
invitationCode = strings.TrimSpace(invitationCode)
if invitationCode == "" || userID <= 0 {
return nil
}
redeemCode, err := s.loadOAuthRegistrationInvitation(ctx, invitationCode)
if err != nil {
if errors.Is(err, ErrRedeemCodeNotFound) {
return nil
}
return fmt.Errorf("load invitation code: %w", err)
}
if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUsed || redeemCode.UsedBy == nil || *redeemCode.UsedBy != userID {
return nil
}
redeemCode.Status = StatusUnused
redeemCode.UsedBy = nil
redeemCode.UsedAt = nil
if err := s.updateOAuthRegistrationInvitation(ctx, redeemCode); err != nil {
return fmt.Errorf("restore invitation code: %w", err)
}
return nil
}
func (s *AuthService) oauthEmailFlowClient(ctx context.Context) *dbent.Client {
if s == nil || s.entClient == nil {
return nil
}
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return s.entClient
}
func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invitationCode string) (*RedeemCode, error) {
if client := s.oauthEmailFlowClient(ctx); client != nil {
entity, err := client.RedeemCode.Query().Where(redeemcode.CodeEQ(invitationCode)).Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrRedeemCodeNotFound
}
return nil, err
}
return &RedeemCode{
ID: entity.ID,
Code: entity.Code,
Type: entity.Type,
Value: entity.Value,
Status: entity.Status,
UsedBy: entity.UsedBy,
UsedAt: entity.UsedAt,
Notes: oauthEmailFlowStringValue(entity.Notes),
CreatedAt: entity.CreatedAt,
GroupID: entity.GroupID,
ValidityDays: entity.ValidityDays,
}, nil
}
return s.redeemRepo.GetByCode(ctx, invitationCode)
}
func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
if client := s.oauthEmailFlowClient(ctx); client != nil {
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
SetStatus(StatusUsed).
SetUsedBy(userID).
SetUsedAt(time.Now().UTC()).
Save(ctx)
if err != nil {
return err
}
if affected == 0 {
return ErrRedeemCodeUsed
}
return nil
}
return s.redeemRepo.Use(ctx, invitationID, userID)
}
func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
if client := s.oauthEmailFlowClient(ctx); client != nil {
update := client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
if code.UsedBy != nil {
update = update.SetUsedBy(*code.UsedBy)
} else {
update = update.ClearUsedBy()
}
if code.UsedAt != nil {
update = update.SetUsedAt(*code.UsedAt)
} else {
update = update.ClearUsedAt()
}
if code.GroupID != nil {
update = update.SetGroupID(*code.GroupID)
} else {
update = update.ClearGroupID()
}
_, err := update.Save(ctx)
return err
}
return s.redeemRepo.Update(ctx, code)
}
func (s *AuthService) updateOAuthSignupSource(ctx context.Context, userID int64, signupSource string) {
client := s.oauthEmailFlowClient(ctx)
if client == nil || userID <= 0 || strings.TrimSpace(signupSource) == "" {
return
}
_ = client.User.UpdateOneID(userID).SetSignupSource(signupSource).Exec(ctx)
}
func oauthEmailFlowStringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
// ValidatePasswordCredentials checks the local password without completing the
// login flow. This is used by pending third-party account adoption flows before
// the external identity has been bound.
func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) {
if s == nil {
return nil, ErrServiceUnavailable
}
user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email)))
if err != nil {
if errors.Is(err, ErrUserNotFound) {
return nil, ErrInvalidCredentials
}
return nil, ErrServiceUnavailable
}
if !user.IsActive() {
return nil, ErrUserNotActive
}
if !s.CheckPassword(password, user.PasswordHash) {
return nil, ErrInvalidCredentials
}
return user, nil
}
// RecordSuccessfulLogin updates last-login activity after a non-standard login
// flow finishes with a real session.
func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) {
if s != nil && s.userRepo != nil && userID > 0 {
user, err := s.userRepo.GetByID(ctx, userID)
if err == nil && user != nil && !isReservedEmail(user.Email) {
s.backfillEmailIdentityOnSuccessfulLogin(ctx, user)
}
}
s.touchUserLogin(ctx, userID)
}
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type redeemCodeRepoStub struct {
codesByCode map[string]*RedeemCode
useCalls []struct {
id int64
userID int64
}
updateCalls []*RedeemCode
}
func (s *redeemCodeRepoStub) Create(context.Context, *RedeemCode) error {
panic("unexpected Create call")
}
func (s *redeemCodeRepoStub) CreateBatch(context.Context, []RedeemCode) error {
panic("unexpected CreateBatch call")
}
func (s *redeemCodeRepoStub) GetByID(context.Context, int64) (*RedeemCode, error) {
panic("unexpected GetByID call")
}
func (s *redeemCodeRepoStub) GetByCode(_ context.Context, code string) (*RedeemCode, error) {
if s.codesByCode == nil {
return nil, ErrRedeemCodeNotFound
}
redeemCode, ok := s.codesByCode[code]
if !ok {
return nil, ErrRedeemCodeNotFound
}
cloned := *redeemCode
return &cloned, nil
}
func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
cloned := *code
s.updateCalls = append(s.updateCalls, &cloned)
if s.codesByCode == nil {
s.codesByCode = make(map[string]*RedeemCode)
}
s.codesByCode[cloned.Code] = &cloned
return nil
}
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *redeemCodeRepoStub) Use(_ context.Context, id, userID int64) error {
for code, redeemCode := range s.codesByCode {
if redeemCode.ID != id {
continue
}
now := time.Now().UTC()
redeemCode.Status = StatusUsed
redeemCode.UsedBy = &userID
redeemCode.UsedAt = &now
s.codesByCode[code] = redeemCode
s.useCalls = append(s.useCalls, struct {
id int64
userID int64
}{id: id, userID: userID})
return nil
}
return ErrRedeemCodeNotFound
}
func (s *redeemCodeRepoStub) List(context.Context, pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *redeemCodeRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *redeemCodeRepoStub) ListByUser(context.Context, int64, int) ([]RedeemCode, error) {
panic("unexpected ListByUser call")
}
func (s *redeemCodeRepoStub) ListByUserPaginated(context.Context, int64, pagination.PaginationParams, string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemCodeRepoStub) SumPositiveBalanceByUser(context.Context, int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
func newOAuthEmailFlowAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
settings map[string]string,
emailCache EmailCache,
) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
emailService := NewEmailService(&settingRepoStub{values: settings}, emailCache)
return NewAuthService(
nil,
userRepo,
redeemRepo,
refreshTokenCache,
cfg,
settingService,
emailService,
nil,
nil,
nil,
nil,
)
}
func TestRegisterOAuthEmailAccountRollsBackCreatedUserWhenTokenPairGenerationFails(t *testing.T) {
userRepo := &userRepoStub{nextID: 42}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUnused,
},
},
}
emailCache := &emailCacheStub{
data: &VerificationCodeData{
Code: "246810",
Attempts: 0,
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(15 * time.Minute),
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
nil,
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
},
emailCache,
)
tokenPair, user, err := authService.RegisterOAuthEmailAccount(
context.Background(),
"fresh@example.com",
"secret-123",
"246810",
"INVITE123",
"oidc",
)
require.Nil(t, tokenPair)
require.Nil(t, user)
require.Error(t, err)
require.Contains(t, err.Error(), "generate token pair")
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, userRepo.created, 1)
require.Empty(t, redeemRepo.useCalls)
require.Empty(t, redeemRepo.updateCalls)
}
func TestRollbackOAuthEmailAccountCreationRestoresInvitationUsage(t *testing.T) {
userRepo := &userRepoStub{}
redeemRepo := &redeemCodeRepoStub{
codesByCode: map[string]*RedeemCode{
"INVITE123": {
ID: 7,
Code: "INVITE123",
Type: RedeemTypeInvitation,
Status: StatusUsed,
UsedBy: func() *int64 {
v := int64(42)
return &v
}(),
UsedAt: func() *time.Time {
v := time.Now().UTC()
return &v
}(),
},
},
}
authService := newOAuthEmailFlowAuthService(
userRepo,
redeemRepo,
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyInvitationCodeEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "INVITE123")
require.NoError(t, err)
require.Equal(t, []int64{42}, userRepo.deletedIDs)
require.Len(t, redeemRepo.updateCalls, 1)
require.Equal(t, StatusUnused, redeemRepo.updateCalls[0].Status)
require.Nil(t, redeemRepo.updateCalls[0].UsedBy)
require.Nil(t, redeemRepo.updateCalls[0].UsedAt)
}
func TestRollbackOAuthEmailAccountCreationPropagatesDeleteError(t *testing.T) {
userRepo := &userRepoStub{deleteErr: errors.New("delete failed")}
authService := newOAuthEmailFlowAuthService(
userRepo,
&redeemCodeRepoStub{},
&refreshTokenCacheStub{},
map[string]string{
SettingKeyRegistrationEnabled: "true",
},
&emailCacheStub{},
)
err := authService.RollbackOAuthEmailAccountCreation(context.Background(), 42, "")
require.Error(t, err)
require.Contains(t, err.Error(), "delete created oauth user")
}
package service
import (
"context"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
entsql "entgo.io/ent/dialect/sql"
)
// ApplyProviderDefaultSettingsOnFirstBind applies provider-specific bootstrap
// settings the first time a user binds a third-party identity. The grant is
// idempotent per user/provider pair.
func (s *AuthService) ApplyProviderDefaultSettingsOnFirstBind(
ctx context.Context,
userID int64,
providerType string,
) error {
if s == nil || s.entClient == nil || s.settingService == nil || userID <= 0 {
return nil
}
if dbent.TxFromContext(ctx) != nil {
return s.applyProviderDefaultSettingsOnFirstBind(ctx, userID, providerType)
}
tx, err := s.entClient.Tx(ctx)
if err != nil {
return fmt.Errorf("begin first bind defaults transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := s.applyProviderDefaultSettingsOnFirstBind(txCtx, userID, providerType); err != nil {
return err
}
return tx.Commit()
}
func (s *AuthService) applyProviderDefaultSettingsOnFirstBind(
ctx context.Context,
userID int64,
providerType string,
) error {
providerDefaults, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, providerType, true)
if err != nil {
return fmt.Errorf("load auth source defaults: %w", err)
}
if !enabled {
return nil
}
client := s.entClient
if tx := dbent.TxFromContext(ctx); tx != nil {
client = tx.Client()
}
var result entsql.Result
if err := client.Driver().Exec(
ctx,
`INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`,
[]any{userID, strings.TrimSpace(providerType), "first_bind"},
&result,
); err != nil {
return fmt.Errorf("record first bind provider grant: %w", err)
}
affected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("read first bind provider grant result: %w", err)
}
if affected == 0 {
return nil
}
if providerDefaults.Balance != 0 {
if err := client.User.UpdateOneID(userID).AddBalance(providerDefaults.Balance).Exec(ctx); err != nil {
return fmt.Errorf("apply first bind balance default: %w", err)
}
}
if providerDefaults.Concurrency != 0 {
if err := client.User.UpdateOneID(userID).AddConcurrency(providerDefaults.Concurrency).Exec(ctx); err != nil {
return fmt.Errorf("apply first bind concurrency default: %w", err)
}
}
if s.defaultSubAssigner != nil {
for _, item := range providerDefaults.Subscriptions {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by first bind defaults",
}); err != nil {
return fmt.Errorf("apply first bind subscription default: %w", err)
}
}
}
return nil
}
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
entsql "entgo.io/ent/dialect/sql"
)
var (
ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
)
const (
defaultPendingAuthTTL = 15 * time.Minute
defaultPendingAuthCompletionTTL = 5 * time.Minute
)
type PendingAuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type CreatePendingAuthSessionInput struct {
SessionToken string
Intent string
Identity PendingAuthIdentityKey
TargetUserID *int64
RedirectTo string
ResolvedEmail string
RegistrationPasswordHash string
BrowserSessionKey string
UpstreamIdentityClaims map[string]any
LocalFlowState map[string]any
ExpiresAt time.Time
}
type IssuePendingAuthCompletionCodeInput struct {
PendingAuthSessionID int64
BrowserSessionKey string
TTL time.Duration
}
type IssuePendingAuthCompletionCodeResult struct {
Code string
ExpiresAt time.Time
}
type PendingIdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type AuthPendingIdentityService struct {
entClient *dbent.Client
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken := strings.TrimSpace(input.SessionToken)
if sessionToken == "" {
var err error
sessionToken, err = randomOpaqueToken(24)
if err != nil {
return nil, err
}
}
expiresAt := input.ExpiresAt.UTC()
if expiresAt.IsZero() {
expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
}
create := s.entClient.PendingAuthSession.Create().
SetSessionToken(sessionToken).
SetIntent(strings.TrimSpace(input.Intent)).
SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
SetExpiresAt(expiresAt)
if input.TargetUserID != nil {
create = create.SetTargetUserID(*input.TargetUserID)
}
return create.Save(ctx)
}
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
code, err := randomOpaqueToken(24)
if err != nil {
return nil, err
}
ttl := input.TTL
if ttl <= 0 {
ttl = defaultPendingAuthCompletionTTL
}
expiresAt := time.Now().UTC().Add(ttl)
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeHash(hashPendingAuthCode(code)).
SetCompletionCodeExpiresAt(expiresAt)
if strings.TrimSpace(input.BrowserSessionKey) != "" {
update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
}
if _, err := update.Save(ctx); err != nil {
return nil, err
}
return &IssuePendingAuthCompletionCodeResult{
Code: code,
ExpiresAt: expiresAt,
}, nil
}
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthCodeInvalid
}
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
}
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
}
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken = strings.TrimSpace(sessionToken)
if sessionToken == "" {
return nil, ErrPendingAuthSessionNotFound
}
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) consumeSession(
ctx context.Context,
session *dbent.PendingAuthSession,
browserSessionKey string,
expiredErr error,
consumedErr error,
) (*dbent.PendingAuthSession, error) {
if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
return nil, err
}
return updated, nil
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil {
return ErrPendingAuthSessionNotFound
}
now := time.Now().UTC()
if session.ConsumedAt != nil {
return consumedErr
}
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
return expiredErr
}
if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
return expiredErr
}
if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return ErrPendingAuthBrowserMismatch
}
return nil
}
func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
if input.IdentityID != nil && *input.IdentityID > 0 {
if _, err := s.entClient.IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(*input.IdentityID),
dbpredicate.IdentityAdoptionDecision(func(s *entsql.Selector) {
col := s.C(identityadoptiondecision.FieldPendingAuthSessionID)
s.Where(entsql.Or(
entsql.IsNull(col),
entsql.NEQ(col, input.PendingAuthSessionID),
))
}),
).
ClearIdentityID().
Save(ctx); err != nil {
return nil, err
}
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func copyPendingMap(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func randomOpaqueToken(byteLen int) (string, error) {
if byteLen <= 0 {
byteLen = 16
}
buf := make([]byte, byteLen)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func hashPendingAuthCode(code string) string {
sum := sha256.Sum256([]byte(code))
return hex.EncodeToString(sum[:])
}
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return NewAuthPendingIdentityService(client), client
}
func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
targetUser, err := client.User.Create().
SetEmail("pending-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
TargetUserID: &targetUser.ID,
RedirectTo: "/profile",
ResolvedEmail: "user@example.com",
BrowserSessionKey: "browser-1",
UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
LocalFlowState: map[string]any{"step": "email_required"},
})
require.NoError(t, err)
require.NotEmpty(t, session.SessionToken)
require.Equal(t, "bind_current_user", session.Intent)
require.Equal(t, "wechat", session.ProviderType)
require.NotNil(t, session.TargetUserID)
require.Equal(t, targetUser.ID, *session.TargetUserID)
require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
require.Equal(t, "email_required", session.LocalFlowState["step"])
}
func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expected",
UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
LocalFlowState: map[string]any{"step": "pending"},
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expected",
})
require.NoError(t, err)
require.NotEmpty(t, issued.Code)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Empty(t, consumed.CompletionCodeHash)
require.Nil(t, consumed.CompletionCodeExpiresAt)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
}
func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expired",
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expired",
TTL: time.Second,
})
require.NoError(t, err)
_, err = client.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-open").
SetProviderSubject("union-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
require.NoError(t, err)
first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
require.NoError(t, err)
require.True(t, first.AdoptDisplayName)
require.False(t, first.AdoptAvatar)
require.Nil(t, first.IdentityID)
second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
require.NoError(t, err)
require.Equal(t, first.ID, second.ID)
require.NotNil(t, second.IdentityID)
require.Equal(t, identity.ID, *second.IdentityID)
require.True(t, second.AdoptAvatar)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ReassignsExistingIdentityReference(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption-reassign@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-open").
SetProviderSubject("union-reassign").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
firstSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-reassign",
},
})
require.NoError(t, err)
firstDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: firstSession.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
require.NoError(t, err)
require.NotNil(t, firstDecision.IdentityID)
require.Equal(t, identity.ID, *firstDecision.IdentityID)
secondSession, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-reassign",
},
})
require.NoError(t, err)
secondDecision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: secondSession.ID,
IdentityID: &identity.ID,
AdoptDisplayName: false,
AdoptAvatar: true,
})
require.NoError(t, err)
require.NotNil(t, secondDecision.IdentityID)
require.Equal(t, identity.ID, *secondDecision.IdentityID)
reloadedFirst, err := client.IdentityAdoptionDecision.Get(ctx, firstDecision.ID)
require.NoError(t, err)
require.Nil(t, reloadedFirst.IdentityID)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision_ClearsLegacyNullSessionReference(t *testing.T) {
t.Skip("legacy NULL pending_auth_session_id rows only exist in production PostgreSQL history; sqlite unit schema rejects NULL")
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("legacy-null-session@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-main").
SetProviderSubject("legacy-null-session").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
_, err = client.ExecContext(
ctx,
`INSERT INTO identity_adoption_decisions
(identity_id, adopt_display_name, adopt_avatar, decided_at, created_at, updated_at, pending_auth_session_id)
VALUES (?, ?, ?, ?, ?, ?, NULL)`,
identity.ID,
true,
false,
time.Now().UTC(),
time.Now().UTC(),
time.Now().UTC(),
)
require.NoError(t, err)
legacyDecision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.IdentityIDEQ(identity.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, legacyDecision.IdentityID)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-main",
ProviderSubject: "legacy-null-session",
},
})
require.NoError(t, err)
decision, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: false,
AdoptAvatar: true,
})
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
reloadedLegacy, err := client.IdentityAdoptionDecision.Get(ctx, legacyDecision.ID)
require.NoError(t, err)
require.Nil(t, reloadedLegacy.IdentityID)
}
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "subject-session-token",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "token",
},
},
})
require.NoError(t, err)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}
......@@ -13,6 +13,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
......@@ -77,6 +78,12 @@ type DefaultSubscriptionAssigner interface {
AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
}
type signupGrantPlan struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
}
// NewAuthService 创建认证服务实例
func NewAuthService(
entClient *dbent.Client,
......@@ -106,6 +113,13 @@ func NewAuthService(
}
}
func (s *AuthService) EntClient() *dbent.Client {
if s == nil {
return nil
}
return s.entClient
}
// Register 用户注册,返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
......@@ -179,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 获取默认配置
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
grantPlan := s.resolveSignupGrantPlan(ctx, "email")
// 创建用户
user := &User{
Email: email,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
......@@ -205,7 +213,8 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
......@@ -469,21 +478,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
......@@ -501,7 +505,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
......@@ -520,7 +525,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
......@@ -584,20 +588,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
signupSource := inferLegacySignupSource(email)
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Balance: grantPlan.Balance,
Concurrency: grantPlan.Concurrency,
Status: StatusActive,
}
......@@ -630,7 +630,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
}
} else {
if err := s.userRepo.Create(ctx, newUser); err != nil {
......@@ -646,7 +647,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.assignDefaultSubscriptions(ctx, user.ID)
s.postAuthUserBootstrap(ctx, user, signupSource, false)
s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults")
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
return nil, nil, ErrInvitationCodeInvalid
......@@ -670,7 +672,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
......@@ -678,77 +679,270 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
}
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: notes,
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
}
}
}
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan {
plan := signupGrantPlan{}
if s != nil && s.cfg != nil {
plan.Balance = s.cfg.Default.UserBalance
plan.Concurrency = s.cfg.Default.UserConcurrency
}
if s == nil || s.settingService == nil {
return plan
}
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
plan.Balance = s.settingService.GetDefaultBalance(ctx)
plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx)
plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx)
resolved, enabled, err := s.settingService.ResolveAuthSourceGrantSettings(ctx, signupSource, false)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err)
return plan
}
if !enabled {
return plan
}
plan.Balance = resolved.Balance
plan.Concurrency = resolved.Concurrency
plan.Subscriptions = resolved.Subscriptions
return plan
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) {
if defaults == nil {
return ProviderDefaultGrantSettings{}, false
}
switch strings.ToLower(strings.TrimSpace(signupSource)) {
case "email":
return defaults.Email, true
case "linuxdo":
return defaults.LinuxDo, true
case "oidc":
return defaults.OIDC, true
case "wechat":
return defaults.WeChat, true
default:
return ProviderDefaultGrantSettings{}, false
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
if strings.TrimSpace(signupSource) == "" {
signupSource = "email"
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
if strings.TrimSpace(signupSource) == "" {
return
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
if err := s.entClient.User.UpdateOneID(userID).
SetSignupSource(signupSource).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
items := s.settingService.GetDefaultSubscriptions(ctx)
for _, item := range items {
if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
Notes: "auto assigned by default user subscriptions setting",
}); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
now := time.Now().UTC()
if err := s.entClient.User.UpdateOneID(userID).
SetLastLoginAt(now).
SetLastActiveAt(now).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
}
}
func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) {
if s == nil || user == nil || user.ID <= 0 {
return
}
identity, created := s.ensureEmailAuthIdentity(ctx, user, "auth_service_login_backfill")
if s.shouldApplyEmailFirstBindDefaults(ctx, user.ID, identity, created) {
if err := s.ApplyProviderDefaultSettingsOnFirstBind(ctx, user.ID, "email"); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to apply email first bind defaults: user_id=%d err=%v", user.ID, err)
}
}
}
func (s *AuthService) shouldApplyEmailFirstBindDefaults(
ctx context.Context,
userID int64,
identity *dbent.AuthIdentity,
created bool,
) bool {
source := emailAuthIdentitySource(identity.Metadata)
if source == "auth_service_login_backfill" {
return false
}
if created {
return true
}
if s == nil || s.entClient == nil || userID <= 0 || identity == nil || identity.UserID != userID {
return false
}
if source != "auth_service_dual_write" {
return false
}
hasGrant, err := s.hasProviderGrantRecord(ctx, userID, "email", "first_bind")
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email first bind grant state: user_id=%d err=%v", userID, err)
return false
}
return !hasGrant
}
func emailAuthIdentitySource(metadata map[string]any) string {
if len(metadata) == 0 {
return ""
}
raw, ok := metadata["source"]
if !ok {
return ""
}
return strings.TrimSpace(fmt.Sprint(raw))
}
func (s *AuthService) hasProviderGrantRecord(
ctx context.Context,
userID int64,
providerType string,
grantReason string,
) (bool, error) {
if s == nil || s.entClient == nil || userID <= 0 {
return false, nil
}
rows, err := s.entClient.QueryContext(
ctx,
`SELECT 1 FROM user_provider_default_grants WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3 LIMIT 1`,
userID,
strings.TrimSpace(providerType),
strings.TrimSpace(grantReason),
)
if err != nil {
return false, err
}
defer func() { _ = rows.Close() }()
return rows.Next(), rows.Err()
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, source string) (*dbent.AuthIdentity, bool) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return nil, false
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return nil, false
}
if strings.TrimSpace(source) == "" {
source = "auth_service_dual_write"
}
client := s.entClient
if tx := dbent.TxFromContext(ctx); tx != nil {
client = tx.Client()
}
buildQuery := func() *dbent.AuthIdentityQuery {
return client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ(email),
)
}
existed, err := buildQuery().Exist(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to inspect email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
if !existed {
if err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
"source": strings.TrimSpace(source),
}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
if isSQLNoRowsError(err) {
return nil, false
}
}
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
}
identity, err := buildQuery().Only(ctx)
if err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to reload email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
return nil, false
}
if identity.UserID != user.ID {
logger.LegacyPrintf("service.auth", "[Auth] Email auth identity ownership mismatch: user_id=%d email=%s owner_id=%d", user.ID, email, identity.UserID)
return nil, false
}
return identity, !existed
}
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
return "oidc"
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
return "wechat"
default:
return "email"
}
}
......@@ -834,7 +1028,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment