Commit b017f461 authored by 陈曦's avatar 陈曦
Browse files

confict fixed by v117

parents 11b97147 d162604f
......@@ -3,6 +3,7 @@ package admin
import (
"strconv"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
......@@ -66,7 +67,7 @@ func (h *PaymentHandler) ListOrders(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, orders, int64(total), page, pageSize)
response.Paginated(c, sanitizeAdminPaymentOrdersForResponse(orders), int64(total), page, pageSize)
}
// GetOrderDetail returns detailed information about a single order.
......@@ -82,7 +83,7 @@ func (h *PaymentHandler) GetOrderDetail(c *gin.Context) {
return
}
auditLogs, _ := h.paymentService.GetOrderAuditLogs(c.Request.Context(), orderID)
response.Success(c, gin.H{"order": order, "auditLogs": auditLogs})
response.Success(c, gin.H{"order": sanitizeAdminPaymentOrderForResponse(order), "auditLogs": auditLogs})
}
// CancelOrder cancels a pending order (admin).
......@@ -114,6 +115,26 @@ func (h *PaymentHandler) RetryFulfillment(c *gin.Context) {
response.Success(c, gin.H{"message": "fulfillment retried"})
}
func sanitizeAdminPaymentOrdersForResponse(orders []*dbent.PaymentOrder) []*dbent.PaymentOrder {
if len(orders) == 0 {
return orders
}
out := make([]*dbent.PaymentOrder, 0, len(orders))
for _, order := range orders {
out = append(out, sanitizeAdminPaymentOrderForResponse(order))
}
return out
}
func sanitizeAdminPaymentOrderForResponse(order *dbent.PaymentOrder) *dbent.PaymentOrder {
if order == nil {
return nil
}
cloned := *order
cloned.ProviderSnapshot = nil
return &cloned
}
// AdminProcessRefundRequest is the request body for admin refund processing.
type AdminProcessRefundRequest struct {
Amount float64 `json:"amount"`
......
......@@ -43,6 +43,15 @@ func scopesContainOpenID(scopes string) bool {
return false
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
......@@ -73,6 +82,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
......@@ -93,114 +107,142 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
paymentCfg = &service.PaymentConfig{}
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
FrontendURL: settings.FrontendURL,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: settings.OIDCConnectEnabled,
OIDCConnectProviderName: settings.OIDCConnectProviderName,
OIDCConnectClientID: settings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
OIDCConnectScopes: settings.OIDCConnectScopes,
OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
PaymentDailyLimit: paymentCfg.DailyLimit,
PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
PaymentEnabledTypes: paymentCfg.EnabledTypes,
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
PaymentHelpImageURL: paymentCfg.HelpImageURL,
PaymentHelpText: paymentCfg.HelpText,
PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
})
payload := dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: settings.PromoCodeEnabled,
PasswordResetEnabled: settings.PasswordResetEnabled,
FrontendURL: settings.FrontendURL,
InvitationCodeEnabled: settings.InvitationCodeEnabled,
TotpEnabled: settings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
WeChatConnectEnabled: settings.WeChatConnectEnabled,
WeChatConnectAppID: settings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
WeChatConnectOpenAppID: settings.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecretConfigured: settings.WeChatConnectOpenAppSecretConfigured,
WeChatConnectMPAppID: settings.WeChatConnectMPAppID,
WeChatConnectMPAppSecretConfigured: settings.WeChatConnectMPAppSecretConfigured,
WeChatConnectMobileAppID: settings.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecretConfigured: settings.WeChatConnectMobileAppSecretConfigured,
WeChatConnectOpenEnabled: settings.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: settings.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: settings.WeChatConnectMobileEnabled,
WeChatConnectMode: settings.WeChatConnectMode,
WeChatConnectScopes: settings.WeChatConnectScopes,
WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: settings.OIDCConnectEnabled,
OIDCConnectProviderName: settings.OIDCConnectProviderName,
OIDCConnectClientID: settings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: settings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: settings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: settings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: settings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: settings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: settings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: settings.OIDCConnectJWKSURL,
OIDCConnectScopes: settings.OIDCConnectScopes,
OIDCConnectRedirectURL: settings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: settings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: settings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: settings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: settings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: settings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: settings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: settings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: settings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: settings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: settings.OIDCConnectUserInfoUsernamePath,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
HomeContent: settings.HomeContent,
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
TableDefaultPageSize: settings.TableDefaultPageSize,
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultUserRPMLimit: settings.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
OpsMonitoringEnabled: opsEnabled && settings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
MaxClaudeCodeVersion: settings.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
BackendModeEnabled: settings.BackendModeEnabled,
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled,
PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled,
OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount,
PaymentDailyLimit: paymentCfg.DailyLimit,
PaymentOrderTimeoutMin: paymentCfg.OrderTimeoutMin,
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
PaymentEnabledTypes: paymentCfg.EnabledTypes,
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
PaymentHelpImageURL: paymentCfg.HelpImageURL,
PaymentHelpText: paymentCfg.HelpText,
PaymentCancelRateLimitEnabled: paymentCfg.CancelRateLimitEnabled,
PaymentCancelRateLimitMax: paymentCfg.CancelRateLimitMax,
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: settings.AvailableChannelsEnabled,
}
response.Success(c, systemSettingsResponseData(payload, authSourceDefaults))
}
// UpdateSettingsRequest 更新设置请求
......@@ -235,6 +277,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// WeChat Connect OAuth 登录
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
WeChatConnectAppSecret string `json:"wechat_connect_app_secret"`
WeChatConnectOpenAppID string `json:"wechat_connect_open_app_id"`
WeChatConnectOpenAppSecret string `json:"wechat_connect_open_app_secret"`
WeChatConnectMPAppID string `json:"wechat_connect_mp_app_id"`
WeChatConnectMPAppSecret string `json:"wechat_connect_mp_app_secret"`
WeChatConnectMobileAppID string `json:"wechat_connect_mobile_app_id"`
WeChatConnectMobileAppSecret string `json:"wechat_connect_mobile_app_secret"`
WeChatConnectOpenEnabled bool `json:"wechat_connect_open_enabled"`
WeChatConnectMPEnabled bool `json:"wechat_connect_mp_enabled"`
WeChatConnectMobileEnabled bool `json:"wechat_connect_mobile_enabled"`
WeChatConnectMode string `json:"wechat_connect_mode"`
WeChatConnectScopes string `json:"wechat_connect_scopes"`
WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"`
WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"`
// Generic OIDC OAuth 登录
OIDCConnectEnabled bool `json:"oidc_connect_enabled"`
OIDCConnectProviderName string `json:"oidc_connect_provider_name"`
......@@ -250,8 +310,8 @@ type UpdateSettingsRequest struct {
OIDCConnectRedirectURL string `json:"oidc_connect_redirect_url"`
OIDCConnectFrontendRedirectURL string `json:"oidc_connect_frontend_redirect_url"`
OIDCConnectTokenAuthMethod string `json:"oidc_connect_token_auth_method"`
OIDCConnectUsePKCE bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken bool `json:"oidc_connect_validate_id_token"`
OIDCConnectUsePKCE *bool `json:"oidc_connect_use_pkce"`
OIDCConnectValidateIDToken *bool `json:"oidc_connect_validate_id_token"`
OIDCConnectAllowedSigningAlgs string `json:"oidc_connect_allowed_signing_algs"`
OIDCConnectClockSkewSeconds int `json:"oidc_connect_clock_skew_seconds"`
OIDCConnectRequireEmailVerified bool `json:"oidc_connect_require_email_verified"`
......@@ -276,9 +336,31 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
......@@ -311,6 +393,15 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
// Payment visible method routing
PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"`
// OpenAI account scheduling
OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
......@@ -341,6 +432,13 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitWindow *int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
// Channel Monitor feature switch
ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
// Available Channels feature switch (user-facing)
AvailableChannelsEnabled *bool `json:"available_channels_enabled"`
}
// UpdateSettings 更新系统设置
......@@ -357,6 +455,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
......@@ -381,6 +484,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions)
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
......@@ -459,7 +566,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
if req.WeChatConnectEnabled {
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
req.WeChatConnectOpenAppID = strings.TrimSpace(req.WeChatConnectOpenAppID)
req.WeChatConnectOpenAppSecret = strings.TrimSpace(req.WeChatConnectOpenAppSecret)
req.WeChatConnectMPAppID = strings.TrimSpace(req.WeChatConnectMPAppID)
req.WeChatConnectMPAppSecret = strings.TrimSpace(req.WeChatConnectMPAppSecret)
req.WeChatConnectMobileAppID = strings.TrimSpace(req.WeChatConnectMobileAppID)
req.WeChatConnectMobileAppSecret = strings.TrimSpace(req.WeChatConnectMobileAppSecret)
req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode))
req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes)
req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL)
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL)
req.WeChatConnectAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectRedirectURL, previousSettings.WeChatConnectRedirectURL))
req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.WeChatConnectFrontendRedirectURL, previousSettings.WeChatConnectFrontendRedirectURL))
if req.WeChatConnectMode == "" {
req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(previousSettings.WeChatConnectMode))
}
if req.WeChatConnectScopes == "" {
req.WeChatConnectScopes = strings.TrimSpace(previousSettings.WeChatConnectScopes)
}
if req.WeChatConnectMPEnabled && req.WeChatConnectMobileEnabled {
response.BadRequest(c, "WeChat Official Account and Mobile App cannot be enabled at the same time")
return
}
if req.WeChatConnectMode != "" {
switch req.WeChatConnectMode {
case "open", "mp", "mobile":
default:
response.BadRequest(c, "WeChat mode must be open, mp, or mobile")
return
}
}
if !req.WeChatConnectOpenEnabled && !req.WeChatConnectMPEnabled && !req.WeChatConnectMobileEnabled {
switch req.WeChatConnectMode {
case "mp":
req.WeChatConnectMPEnabled = true
case "mobile":
req.WeChatConnectMobileEnabled = true
default:
req.WeChatConnectOpenEnabled = true
}
}
if req.WeChatConnectMode == "" {
if req.WeChatConnectMPEnabled {
req.WeChatConnectMode = "mp"
} else if req.WeChatConnectMobileEnabled {
req.WeChatConnectMode = "mobile"
} else {
req.WeChatConnectMode = "open"
}
}
req.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectOpenAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMPAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMPAppID, previousSettings.WeChatConnectAppID))
req.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(req.WeChatConnectMobileAppID, req.WeChatConnectAppID, previousSettings.WeChatConnectMobileAppID, previousSettings.WeChatConnectAppID))
if req.WeChatConnectOpenAppSecret == "" {
req.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectOpenAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
}
if req.WeChatConnectMPAppSecret == "" {
req.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMPAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
}
if req.WeChatConnectMobileAppSecret == "" {
req.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(previousSettings.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret, req.WeChatConnectAppSecret))
}
if req.WeChatConnectAppSecret == "" {
req.WeChatConnectAppSecret = strings.TrimSpace(firstNonEmpty(req.WeChatConnectOpenAppSecret, req.WeChatConnectMPAppSecret, req.WeChatConnectMobileAppSecret, previousSettings.WeChatConnectAppSecret))
}
if req.WeChatConnectOpenEnabled {
if req.WeChatConnectOpenAppID == "" {
response.BadRequest(c, "WeChat PC App ID is required when enabled")
return
}
if req.WeChatConnectOpenAppSecret == "" {
response.BadRequest(c, "WeChat PC App Secret is required when enabled")
return
}
}
if req.WeChatConnectMPEnabled {
if req.WeChatConnectMPAppID == "" {
response.BadRequest(c, "WeChat Official Account App ID is required when enabled")
return
}
if req.WeChatConnectMPAppSecret == "" {
response.BadRequest(c, "WeChat Official Account App Secret is required when enabled")
return
}
}
if req.WeChatConnectMobileEnabled {
if req.WeChatConnectMobileAppID == "" {
response.BadRequest(c, "WeChat Mobile App ID is required when enabled")
return
}
if req.WeChatConnectMobileAppSecret == "" {
response.BadRequest(c, "WeChat Mobile App Secret is required when enabled")
return
}
}
if req.WeChatConnectScopes == "" {
if req.WeChatConnectMPEnabled {
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode("mp")
} else {
req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode)
}
}
if req.WeChatConnectOpenEnabled || req.WeChatConnectMPEnabled {
if req.WeChatConnectRedirectURL == "" {
response.BadRequest(c, "WeChat Redirect URL is required when web oauth is enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil {
response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL")
return
}
if req.WeChatConnectFrontendRedirectURL == "" {
req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback"
}
if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil {
response.BadRequest(c, "WeChat Frontend Redirect URL is invalid")
return
}
}
}
// Generic OIDC 参数验证
oidcUsePKCE, oidcValidateIDToken, err := h.settingService.OIDCSecurityWriteDefaults(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.OIDCConnectEnabled {
req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName)
req.OIDCConnectClientID = strings.TrimSpace(req.OIDCConnectClientID)
......@@ -478,10 +719,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(req.OIDCConnectUserInfoEmailPath)
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(req.OIDCConnectUserInfoIDPath)
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(req.OIDCConnectUserInfoUsernamePath)
if req.OIDCConnectProviderName == "" {
req.OIDCConnectProviderName = "OIDC"
req.OIDCConnectProviderName = strings.TrimSpace(firstNonEmpty(req.OIDCConnectProviderName, previousSettings.OIDCConnectProviderName, "OIDC"))
req.OIDCConnectClientID = strings.TrimSpace(firstNonEmpty(req.OIDCConnectClientID, previousSettings.OIDCConnectClientID))
req.OIDCConnectIssuerURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectIssuerURL, previousSettings.OIDCConnectIssuerURL))
req.OIDCConnectDiscoveryURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectDiscoveryURL, previousSettings.OIDCConnectDiscoveryURL))
req.OIDCConnectAuthorizeURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAuthorizeURL, previousSettings.OIDCConnectAuthorizeURL))
req.OIDCConnectTokenURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenURL, previousSettings.OIDCConnectTokenURL))
req.OIDCConnectUserInfoURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoURL, previousSettings.OIDCConnectUserInfoURL))
req.OIDCConnectJWKSURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectJWKSURL, previousSettings.OIDCConnectJWKSURL))
req.OIDCConnectScopes = strings.TrimSpace(firstNonEmpty(req.OIDCConnectScopes, previousSettings.OIDCConnectScopes, "openid email profile"))
req.OIDCConnectRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectRedirectURL, previousSettings.OIDCConnectRedirectURL))
req.OIDCConnectFrontendRedirectURL = strings.TrimSpace(firstNonEmpty(req.OIDCConnectFrontendRedirectURL, previousSettings.OIDCConnectFrontendRedirectURL, "/auth/oidc/callback"))
req.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(firstNonEmpty(req.OIDCConnectTokenAuthMethod, previousSettings.OIDCConnectTokenAuthMethod, "client_secret_post")))
req.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(firstNonEmpty(req.OIDCConnectAllowedSigningAlgs, previousSettings.OIDCConnectAllowedSigningAlgs, "RS256,ES256,PS256"))
req.OIDCConnectUserInfoEmailPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoEmailPath, previousSettings.OIDCConnectUserInfoEmailPath))
req.OIDCConnectUserInfoIDPath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoIDPath, previousSettings.OIDCConnectUserInfoIDPath))
req.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(firstNonEmpty(req.OIDCConnectUserInfoUsernamePath, previousSettings.OIDCConnectUserInfoUsernamePath))
if req.OIDCConnectUsePKCE != nil {
oidcUsePKCE = *req.OIDCConnectUsePKCE
}
if req.OIDCConnectValidateIDToken != nil {
oidcValidateIDToken = *req.OIDCConnectValidateIDToken
}
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = previousSettings.OIDCConnectClockSkewSeconds
if req.OIDCConnectClockSkewSeconds == 0 {
req.OIDCConnectClockSkewSeconds = 120
}
}
if req.OIDCConnectClientID == "" {
response.BadRequest(c, "OIDC Client ID is required when enabled")
return
......@@ -544,19 +810,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none")
return
}
if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE {
response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none")
return
}
if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 {
response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600")
return
}
if req.OIDCConnectValidateIDToken {
if req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
if oidcValidateIDToken && req.OIDCConnectAllowedSigningAlgs == "" {
response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true")
return
}
if req.OIDCConnectJWKSURL != "" {
if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil {
......@@ -805,6 +1065,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
WeChatConnectEnabled: req.WeChatConnectEnabled,
WeChatConnectAppID: req.WeChatConnectAppID,
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
WeChatConnectMode: req.WeChatConnectMode,
WeChatConnectScopes: req.WeChatConnectScopes,
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
......@@ -819,8 +1095,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: req.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: req.OIDCConnectValidateIDToken,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
......@@ -843,6 +1119,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
......@@ -897,6 +1174,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableCCHSigning
}(),
PaymentVisibleMethodAlipaySource: func() string {
if req.PaymentVisibleMethodAlipaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource)
}
return previousSettings.PaymentVisibleMethodAlipaySource
}(),
PaymentVisibleMethodWxpaySource: func() string {
if req.PaymentVisibleMethodWxpaySource != nil {
return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource)
}
return previousSettings.PaymentVisibleMethodWxpaySource
}(),
PaymentVisibleMethodAlipayEnabled: func() bool {
if req.PaymentVisibleMethodAlipayEnabled != nil {
return *req.PaymentVisibleMethodAlipayEnabled
}
return previousSettings.PaymentVisibleMethodAlipayEnabled
}(),
PaymentVisibleMethodWxpayEnabled: func() bool {
if req.PaymentVisibleMethodWxpayEnabled != nil {
return *req.PaymentVisibleMethodWxpayEnabled
}
return previousSettings.PaymentVisibleMethodWxpayEnabled
}(),
OpenAIAdvancedSchedulerEnabled: func() bool {
if req.OpenAIAdvancedSchedulerEnabled != nil {
return *req.OpenAIAdvancedSchedulerEnabled
}
return previousSettings.OpenAIAdvancedSchedulerEnabled
}(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
......@@ -927,9 +1234,58 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.AccountQuotaNotifyEmails
}(),
ChannelMonitorEnabled: func() bool {
if req.ChannelMonitorEnabled != nil {
return *req.ChannelMonitorEnabled
}
return previousSettings.ChannelMonitorEnabled
}(),
ChannelMonitorDefaultIntervalSeconds: func() int {
if req.ChannelMonitorDefaultIntervalSeconds != nil {
return *req.ChannelMonitorDefaultIntervalSeconds
}
return previousSettings.ChannelMonitorDefaultIntervalSeconds
}(),
AvailableChannelsEnabled: func() bool {
if req.AvailableChannelsEnabled != nil {
return *req.AvailableChannelsEnabled
}
return previousSettings.AvailableChannelsEnabled
}(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
authSourceDefaults := &service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind),
},
LinuxDo: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind),
},
OIDC: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind),
},
WeChat: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
response.ErrorFrom(c, err)
return
}
......@@ -969,7 +1325,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
h.auditSettingsUpdate(c, previousSettings, settings, req)
h.auditSettingsUpdate(c, previousSettings, settings, previousAuthSourceDefaults, authSourceDefaults, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
......@@ -977,6 +1333,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
......@@ -994,113 +1355,141 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
updatedPaymentCfg = &service.PaymentConfig{}
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
FrontendURL: updatedSettings.FrontendURL,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
PaymentHelpText: updatedPaymentCfg.HelpText,
PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
})
payload := dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
FrontendURL: updatedSettings.FrontendURL,
InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled,
TotpEnabled: updatedSettings.TotpEnabled,
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
WeChatConnectOpenAppID: updatedSettings.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecretConfigured: updatedSettings.WeChatConnectOpenAppSecretConfigured,
WeChatConnectMPAppID: updatedSettings.WeChatConnectMPAppID,
WeChatConnectMPAppSecretConfigured: updatedSettings.WeChatConnectMPAppSecretConfigured,
WeChatConnectMobileAppID: updatedSettings.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecretConfigured: updatedSettings.WeChatConnectMobileAppSecretConfigured,
WeChatConnectOpenEnabled: updatedSettings.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: updatedSettings.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: updatedSettings.WeChatConnectMobileEnabled,
WeChatConnectMode: updatedSettings.WeChatConnectMode,
WeChatConnectScopes: updatedSettings.WeChatConnectScopes,
WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled,
OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName,
OIDCConnectClientID: updatedSettings.OIDCConnectClientID,
OIDCConnectClientSecretConfigured: updatedSettings.OIDCConnectClientSecretConfigured,
OIDCConnectIssuerURL: updatedSettings.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: updatedSettings.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: updatedSettings.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: updatedSettings.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: updatedSettings.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: updatedSettings.OIDCConnectJWKSURL,
OIDCConnectScopes: updatedSettings.OIDCConnectScopes,
OIDCConnectRedirectURL: updatedSettings.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: updatedSettings.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: updatedSettings.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: updatedSettings.OIDCConnectUsePKCE,
OIDCConnectValidateIDToken: updatedSettings.OIDCConnectValidateIDToken,
OIDCConnectAllowedSigningAlgs: updatedSettings.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: updatedSettings.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: updatedSettings.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: updatedSettings.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: updatedSettings.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: updatedSettings.OIDCConnectUserInfoUsernamePath,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
HomeContent: updatedSettings.HomeContent,
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
TableDefaultPageSize: updatedSettings.TableDefaultPageSize,
TablePageSizeOptions: updatedSettings.TablePageSizeOptions,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
MaxClaudeCodeVersion: updatedSettings.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
BackendModeEnabled: updatedSettings.BackendModeEnabled,
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource,
PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource,
PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled,
PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled,
OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
PaymentDailyLimit: updatedPaymentCfg.DailyLimit,
PaymentOrderTimeoutMin: updatedPaymentCfg.OrderTimeoutMin,
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
PaymentHelpImageURL: updatedPaymentCfg.HelpImageURL,
PaymentHelpText: updatedPaymentCfg.HelpText,
PaymentCancelRateLimitEnabled: updatedPaymentCfg.CancelRateLimitEnabled,
PaymentCancelRateLimitMax: updatedPaymentCfg.CancelRateLimitMax,
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled,
}
response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults))
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
......@@ -1117,12 +1506,12 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
changed := diffSettings(before, after, beforeAuthSourceDefaults, afterAuthSourceDefaults, req)
if len(changed) == 0 {
return
}
......@@ -1137,7 +1526,7 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
......@@ -1205,6 +1594,54 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
changed = append(changed, "wechat_connect_enabled")
}
if before.WeChatConnectAppID != after.WeChatConnectAppID {
changed = append(changed, "wechat_connect_app_id")
}
if req.WeChatConnectAppSecret != "" {
changed = append(changed, "wechat_connect_app_secret")
}
if before.WeChatConnectOpenAppID != after.WeChatConnectOpenAppID {
changed = append(changed, "wechat_connect_open_app_id")
}
if req.WeChatConnectOpenAppSecret != "" {
changed = append(changed, "wechat_connect_open_app_secret")
}
if before.WeChatConnectMPAppID != after.WeChatConnectMPAppID {
changed = append(changed, "wechat_connect_mp_app_id")
}
if req.WeChatConnectMPAppSecret != "" {
changed = append(changed, "wechat_connect_mp_app_secret")
}
if before.WeChatConnectMobileAppID != after.WeChatConnectMobileAppID {
changed = append(changed, "wechat_connect_mobile_app_id")
}
if req.WeChatConnectMobileAppSecret != "" {
changed = append(changed, "wechat_connect_mobile_app_secret")
}
if before.WeChatConnectOpenEnabled != after.WeChatConnectOpenEnabled {
changed = append(changed, "wechat_connect_open_enabled")
}
if before.WeChatConnectMPEnabled != after.WeChatConnectMPEnabled {
changed = append(changed, "wechat_connect_mp_enabled")
}
if before.WeChatConnectMobileEnabled != after.WeChatConnectMobileEnabled {
changed = append(changed, "wechat_connect_mobile_enabled")
}
if before.WeChatConnectMode != after.WeChatConnectMode {
changed = append(changed, "wechat_connect_mode")
}
if before.WeChatConnectScopes != after.WeChatConnectScopes {
changed = append(changed, "wechat_connect_scopes")
}
if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL {
changed = append(changed, "wechat_connect_redirect_url")
}
if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL {
changed = append(changed, "wechat_connect_frontend_redirect_url")
}
if before.OIDCConnectEnabled != after.OIDCConnectEnabled {
changed = append(changed, "oidc_connect_enabled")
}
......@@ -1376,6 +1813,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource {
changed = append(changed, "payment_visible_method_alipay_source")
}
if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource {
changed = append(changed, "payment_visible_method_wxpay_source")
}
if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled {
changed = append(changed, "payment_visible_method_alipay_enabled")
}
if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled {
changed = append(changed, "payment_visible_method_wxpay_enabled")
}
if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
changed = append(changed, "openai_advanced_scheduler_enabled")
}
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
......@@ -1392,6 +1844,59 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
if before.ChannelMonitorEnabled != after.ChannelMonitorEnabled {
changed = append(changed, "channel_monitor_enabled")
}
if before.ChannelMonitorDefaultIntervalSeconds != after.ChannelMonitorDefaultIntervalSeconds {
changed = append(changed, "channel_monitor_default_interval_seconds")
}
if before.AvailableChannelsEnabled != after.AvailableChannelsEnabled {
changed = append(changed, "available_channels_enabled")
}
changed = appendAuthSourceDefaultChanges(changed, beforeAuthSourceDefaults, afterAuthSourceDefaults)
return changed
}
func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSourceDefaultSettings, after *service.AuthSourceDefaultSettings) []string {
if before == nil {
before = &service.AuthSourceDefaultSettings{}
}
if after == nil {
after = &service.AuthSourceDefaultSettings{}
}
type providerDefaultGrantField struct {
name string
before service.ProviderDefaultGrantSettings
after service.ProviderDefaultGrantSettings
}
fields := []providerDefaultGrantField{
{name: "email", before: before.Email, after: after.Email},
{name: "linuxdo", before: before.LinuxDo, after: after.LinuxDo},
{name: "oidc", before: before.OIDC, after: after.OIDC},
{name: "wechat", before: before.WeChat, after: after.WeChat},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
changed = append(changed, "auth_source_default_"+field.name+"_balance")
}
if field.before.Concurrency != field.after.Concurrency {
changed = append(changed, "auth_source_default_"+field.name+"_concurrency")
}
if !equalDefaultSubscriptions(field.before.Subscriptions, field.after.Subscriptions) {
changed = append(changed, "auth_source_default_"+field.name+"_subscriptions")
}
if field.before.GrantOnSignup != field.after.GrantOnSignup {
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_signup")
}
if field.before.GrantOnFirstBind != field.after.GrantOnFirstBind {
changed = append(changed, "auth_source_default_"+field.name+"_grant_on_first_bind")
}
}
if before.ForceEmailOnThirdPartySignup != after.ForceEmailOnThirdPartySignup {
changed = append(changed, "force_email_on_third_party_signup")
}
return changed
}
......@@ -1412,6 +1917,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto
return normalized
}
func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting {
if input == nil {
return nil
}
normalized := normalizeDefaultSubscriptions(*input)
return &normalized
}
func float64ValueOrDefault(value *float64, fallback float64) float64 {
if value == nil {
return fallback
}
return *value
}
func intValueOrDefault(value *int, fallback int) int {
if value == nil {
return fallback
}
return *value
}
func boolValueOrDefault(value *bool, fallback bool) bool {
if value == nil {
return fallback
}
return *value
}
func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting {
if input == nil {
return fallback
}
result := make([]service.DefaultSubscriptionSetting, 0, len(*input))
for _, item := range *input {
result = append(result, service.DefaultSubscriptionSetting{
GroupID: item.GroupID,
ValidityDays: item.ValidityDays,
})
}
return result
}
func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any {
data := make(map[string]any)
raw, err := json.Marshal(settings)
if err == nil {
_ = json.Unmarshal(raw, &data)
}
if authSourceDefaults == nil {
authSourceDefaults = &service.AuthSourceDefaultSettings{}
}
data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance
data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency
data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions
data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup
data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind
data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance
data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup
data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind
data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance
data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency
data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions
data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup
data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind
data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup
return data
}
func equalStringSlice(a, b []string) bool {
if len(a) != len(b) {
return false
......
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type settingHandlerRepoStub struct {
values map[string]string
lastUpdates map[string]string
}
func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.lastUpdates = make(map[string]string, len(settings))
for key, value := range settings {
s.lastUpdates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for key, value := range s.values {
out[key] = value
}
return out, nil
}
func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type failingAuthSourceSettingsRepoStub struct {
values map[string]string
err error
}
func (s *failingAuthSourceSettingsRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *failingAuthSourceSettingsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *failingAuthSourceSettingsRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *failingAuthSourceSettingsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *failingAuthSourceSettingsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
if _, ok := settings[service.SettingKeyAuthSourceDefaultEmailBalance]; ok {
return s.err
}
for key, value := range settings {
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *failingAuthSourceSettingsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for key, value := range s.values {
out[key] = value
}
return out, nil
}
func (s *failingAuthSourceSettingsRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil)
handler.GetSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 9.5, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any)
require.True(t, ok)
require.Len(t, subscriptions, 1)
}
func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
"promo_code_enabled": true,
"auth_source_default_email_balance": 12.75,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions])
require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, 12.75, data["auth_source_default_email_balance"])
require.Equal(t, float64(8), data["auth_source_default_email_concurrency"])
require.Equal(t, true, data["force_email_on_third_party_signup"])
}
func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"payment_visible_method_alipay_source": "easypay",
"payment_visible_method_wxpay_source": "wxpay",
"payment_visible_method_alipay_enabled": true,
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource])
require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource])
require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled])
require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"])
require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"])
require.Equal(t, true, data["payment_visible_method_alipay_enabled"])
require.Equal(t, false, data["payment_visible_method_wxpay_enabled"])
require.Equal(t, true, data["openai_advanced_scheduler_enabled"])
}
func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingPaymentVisibleMethodAlipayEnabled: "true",
service.SettingPaymentVisibleMethodAlipaySource: "",
service.SettingPaymentVisibleMethodWxpayEnabled: "false",
service.SettingPaymentVisibleMethodWxpaySource: "",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "", repo.values[service.SettingPaymentVisibleMethodAlipaySource])
require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled])
}
func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFlags(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectUsePKCE: "true",
service.SettingKeyOIDCConnectValidateIDToken: "true",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
"oidc_connect_use_pkce": false,
"oidc_connect_validate_id_token": false,
"oidc_connect_allowed_signing_algs": "",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["oidc_connect_use_pkce"])
require.Equal(t, false, data["oidc_connect_validate_id_token"])
}
func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaultsOnLegacyUpgrade(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: "OIDC",
service.SettingKeyOIDCConnectClientID: "oidc-client",
service.SettingKeyOIDCConnectClientSecret: "oidc-secret",
service.SettingKeyOIDCConnectIssuerURL: "https://issuer.example.com",
service.SettingKeyOIDCConnectAuthorizeURL: "https://issuer.example.com/auth",
service.SettingKeyOIDCConnectTokenURL: "https://issuer.example.com/token",
service.SettingKeyOIDCConnectUserInfoURL: "https://issuer.example.com/userinfo",
service.SettingKeyOIDCConnectJWKSURL: "https://issuer.example.com/jwks",
service.SettingKeyOIDCConnectScopes: "openid email profile",
service.SettingKeyOIDCConnectRedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256",
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: "false",
service.SettingKeyOIDCConnectUserInfoEmailPath: "",
service.SettingKeyOIDCConnectUserInfoIDPath: "",
service.SettingKeyOIDCConnectUserInfoUsernamePath: "",
},
}
svc := service.NewSettingService(repo, &config.Config{
Default: config.DefaultConfig{UserConcurrency: 5},
OIDC: config.OIDCConnectConfig{
Enabled: true,
ProviderName: "OIDC",
ClientID: "oidc-client",
ClientSecret: "oidc-secret",
IssuerURL: "https://issuer.example.com",
AuthorizeURL: "https://issuer.example.com/auth",
TokenURL: "https://issuer.example.com/token",
UserInfoURL: "https://issuer.example.com/userinfo",
JWKSURL: "https://issuer.example.com/jwks",
Scopes: "openid email profile",
RedirectURL: "https://example.com/api/v1/auth/oauth/oidc/callback",
FrontendRedirectURL: "/auth/oidc/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
ValidateIDToken: true,
AllowedSigningAlgs: "RS256",
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"oidc_connect_enabled": true,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectUsePKCE])
require.Equal(t, "false", repo.values[service.SettingKeyOIDCConnectValidateIDToken])
}
func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &settingHandlerRepoStub{
values: map[string]string{
service.SettingKeyPromoCodeEnabled: "true",
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
"payment_visible_method_alipay_source": "bogus",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource)
}
func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAuthSourceDefaultsFail(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := &failingAuthSourceSettingsRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "false",
service.SettingKeyPromoCodeEnabled: "true",
service.SettingKeyAuthSourceDefaultEmailBalance: "9.5",
service.SettingKeyAuthSourceDefaultEmailConcurrency: "8",
service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`,
},
err: errors.New("write auth source defaults failed"),
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
"promo_code_enabled": true,
"auth_source_default_email_balance": 12.75,
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusInternalServerError, rec.Code)
require.Equal(t, "false", repo.values[service.SettingKeyRegistrationEnabled])
require.Equal(t, "9.5", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance])
}
func TestDiffSettings_IncludesAuthSourceDefaultsAndForceEmail(t *testing.T) {
changed := diffSettings(
&service.SystemSettings{},
&service.SystemSettings{},
&service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: 0,
Concurrency: 5,
Subscriptions: nil,
GrantOnSignup: true,
GrantOnFirstBind: false,
},
ForceEmailOnThirdPartySignup: false,
},
&service.AuthSourceDefaultSettings{
Email: service.ProviderDefaultGrantSettings{
Balance: 12.5,
Concurrency: 7,
Subscriptions: []service.DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 30}},
GrantOnSignup: false,
GrantOnFirstBind: true,
},
ForceEmailOnThirdPartySignup: true,
},
UpdateSettingsRequest{},
)
require.Contains(t, changed, "auth_source_default_email_balance")
require.Contains(t, changed, "auth_source_default_email_concurrency")
require.Contains(t, changed, "auth_source_default_email_subscriptions")
require.Contains(t, changed, "auth_source_default_email_grant_on_signup")
require.Contains(t, changed, "auth_source_default_email_grant_on_first_bind")
require.Contains(t, changed, "force_email_on_third_party_signup")
}
......@@ -40,6 +40,7 @@ type CreateUserRequest struct {
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
RPMLimit int `json:"rpm_limit"`
AllowedGroups []int64 `json:"allowed_groups"`
}
......@@ -52,6 +53,7 @@ type UpdateUserRequest struct {
Notes *string `json:"notes"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
RPMLimit *int `json:"rpm_limit"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
......@@ -66,6 +68,22 @@ type UpdateBalanceRequest struct {
Notes string `json:"notes"`
}
type BindUserAuthIdentityRequest struct {
ProviderType string `json:"provider_type"`
ProviderKey string `json:"provider_key"`
ProviderSubject string `json:"provider_subject"`
Issuer *string `json:"issuer"`
Metadata map[string]any `json:"metadata"`
Channel *BindUserAuthIdentityChannelRequest `json:"channel"`
}
type BindUserAuthIdentityChannelRequest struct {
Channel string `json:"channel"`
ChannelAppID string `json:"channel_app_id"`
ChannelSubject string `json:"channel_subject"`
Metadata map[string]any `json:"metadata"`
}
// List handles listing all users with pagination
// GET /api/v1/admin/users
// Query params:
......@@ -172,6 +190,45 @@ func (h *UserHandler) GetByID(c *gin.Context) {
response.Success(c, dto.UserFromServiceAdmin(user))
}
// BindAuthIdentity manually binds a canonical auth identity to a user.
// POST /api/v1/admin/users/:id/auth-identities
func (h *UserHandler) BindAuthIdentity(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req BindUserAuthIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
input := service.AdminBindAuthIdentityInput{
ProviderType: req.ProviderType,
ProviderKey: req.ProviderKey,
ProviderSubject: req.ProviderSubject,
Issuer: req.Issuer,
Metadata: req.Metadata,
}
if req.Channel != nil {
input.Channel = &service.AdminBindAuthIdentityChannelInput{
Channel: req.Channel.Channel,
ChannelAppID: req.Channel.ChannelAppID,
ChannelSubject: req.Channel.ChannelSubject,
Metadata: req.Channel.Metadata,
}
}
result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
......@@ -188,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
......@@ -221,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
RPMLimit: req.RPMLimit,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
......@@ -400,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) {
"migrated_keys": result.MigratedKeys,
})
}
// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量
// GET /api/v1/admin/users/:id/rpm-status
func (h *UserHandler) GetUserRPMStatus(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, status)
}
//go:build unit
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestUserHandlerListIncludesActivityFieldsAndSortParams(t *testing.T) {
gin.SetMode(gin.TestMode)
lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
lastActiveAt := lastLoginAt.Add(30 * time.Minute)
lastUsedAt := lastLoginAt.Add(90 * time.Minute)
adminSvc := newStubAdminService()
adminSvc.users = []service.User{
{
ID: 7,
Email: "activity@example.com",
Username: "activity-user",
Role: service.RoleUser,
Status: service.StatusActive,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
CreatedAt: lastLoginAt.Add(-24 * time.Hour),
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(
http.MethodGet,
"/api/v1/admin/users?sort_by=last_used_at&sort_order=asc&search=activity",
nil,
)
handler.List(c)
require.Equal(t, http.StatusOK, recorder.Code)
require.Equal(t, "last_used_at", adminSvc.lastListUsers.sortBy)
require.Equal(t, "asc", adminSvc.lastListUsers.sortOrder)
require.Equal(t, "activity", adminSvc.lastListUsers.filters.Search)
var resp struct {
Code int `json:"code"`
Data struct {
Items []struct {
LastActiveAt *time.Time `json:"last_active_at"`
LastUsedAt *time.Time `json:"last_used_at"`
} `json:"items"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Len(t, resp.Data.Items, 1)
require.WithinDuration(t, lastActiveAt, *resp.Data.Items[0].LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *resp.Data.Items[0].LastUsedAt, time.Second)
}
func TestUserHandlerGetByIDIncludesActivityFields(t *testing.T) {
gin.SetMode(gin.TestMode)
lastLoginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC)
lastActiveAt := lastLoginAt.Add(30 * time.Minute)
lastUsedAt := lastLoginAt.Add(90 * time.Minute)
adminSvc := newStubAdminService()
adminSvc.users = []service.User{
{
ID: 8,
Email: "detail@example.com",
Username: "detail-user",
Role: service.RoleUser,
Status: service.StatusActive,
LastActiveAt: &lastActiveAt,
LastUsedAt: &lastUsedAt,
CreatedAt: lastLoginAt.Add(-24 * time.Hour),
UpdatedAt: lastLoginAt,
},
}
handler := NewUserHandler(adminSvc, nil)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Params = gin.Params{{Key: "id", Value: "8"}}
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/8", nil)
handler.GetByID(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data struct {
LastActiveAt *time.Time `json:"last_active_at"`
LastUsedAt *time.Time `json:"last_used_at"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.WithinDuration(t, lastActiveAt, *resp.Data.LastActiveAt, time.Second)
require.WithinDuration(t, lastUsedAt, *resp.Data.LastUsedAt, time.Second)
}
//go:build unit
package handler
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAuthHandlerGetCurrentUserReturnsProfileCompatibilityFields(t *testing.T) {
gin.SetMode(gin.TestMode)
verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC)
repo := &userHandlerRepoStub{
user: &service.User{
ID: 31,
Email: "me@example.com",
Username: "linuxdo-handle",
Role: service.RoleUser,
Status: service.StatusActive,
AvatarURL: "https://cdn.example.com/linuxdo.png",
AvatarSource: "remote_url",
},
identities: []service.UserAuthIdentityRecord{
{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "linuxdo-subject-31",
VerifiedAt: &verifiedAt,
Metadata: map[string]any{
"username": "linuxdo-handle",
"avatar_url": "https://cdn.example.com/linuxdo.png",
},
},
},
}
handler := &AuthHandler{
userService: service.NewUserService(repo, nil, nil, nil),
}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/me", nil)
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 31})
handler.GetCurrentUser(c)
require.Equal(t, http.StatusOK, recorder.Code)
var resp struct {
Code int `json:"code"`
Data map[string]any `json:"data"`
}
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, true, resp.Data["email_bound"])
require.Equal(t, true, resp.Data["linuxdo_bound"])
require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"])
authBindings, ok := resp.Data["auth_bindings"].(map[string]any)
require.True(t, ok)
linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, linuxdoBinding["bound"])
avatarSource, ok := resp.Data["avatar_source"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", avatarSource["provider"])
require.Equal(t, "linuxdo", avatarSource["source"])
profileSources, ok := resp.Data["profile_sources"].(map[string]any)
require.True(t, ok)
usernameSource, ok := profileSources["username"].(map[string]any)
require.True(t, ok)
require.Equal(t, "linuxdo", usernameSource["provider"])
require.Equal(t, "linuxdo", usernameSource["source"])
}
package handler
import (
"context"
"log/slog"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
......@@ -76,9 +78,24 @@ type AuthResponse struct {
User *dto.User `json:"user"`
}
func ensureLoginUserActive(user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if !user.IsActive() {
return service.ErrUserNotActive
}
return nil
}
// respondWithTokenPair 生成 Token 对并返回认证响应
// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容)
func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
slog.Error("failed to generate token pair", "error", err, "user_id", user.ID)
......@@ -104,6 +121,34 @@ func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) {
})
}
func (h *AuthHandler) ensureBackendModeAllowsUser(ctx context.Context, user *service.User) error {
if user == nil {
return infraerrors.Unauthorized("INVALID_USER", "user not found")
}
if h == nil || !h.isBackendModeEnabled(ctx) || user.IsAdmin() {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) ensureBackendModeAllowsNewUserLogin(ctx context.Context) error {
if h == nil || !h.isBackendModeEnabled(ctx) {
return nil
}
return infraerrors.Forbidden("BACKEND_MODE_ADMIN_ONLY", "Backend mode is active. Only admin login is allowed.")
}
func (h *AuthHandler) isBackendModeEnabled(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
settings, err := h.settingSvc.GetPublicSettings(ctx)
if err == nil && settings != nil {
return settings.BackendModeEnabled
}
return h.settingSvc.IsBackendModeEnabled(ctx)
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
......@@ -177,6 +222,11 @@ func (h *AuthHandler) Login(c *gin.Context) {
}
_ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
// Check if TOTP 2FA is enabled for this user
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
// Create a temporary login session for 2FA
......@@ -194,11 +244,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
// Backend mode: only admin can login
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
h.respondWithTokenPair(c, user)
}
......@@ -262,16 +308,80 @@ func (h *AuthHandler) Login2FA(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
if err := ensureLoginUserActive(user); err != nil {
response.ErrorFrom(c, err)
return
}
// Backend mode: only admin can login (check BEFORE deleting session)
if h.settingSvc.IsBackendModeEnabled(c.Request.Context()) && !user.IsAdmin() {
response.Forbidden(c, "Backend mode is active. Only admin login is allowed.")
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
if session.PendingOAuthBind != nil {
pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
pendingSession, err := pendingSvc.GetBrowserSession(
c.Request.Context(),
session.PendingOAuthBind.PendingSessionToken,
session.PendingOAuthBind.BrowserSessionKey,
)
if err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, pendingSession.ID, oauthAdoptionDecisionRequest{})
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthBinding(
c.Request.Context(),
h.entClient(),
h.authService,
h.userService,
pendingSession,
decision,
&user.ID,
true,
true,
); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
if _, err := pendingSvc.ConsumeBrowserSession(
c.Request.Context(),
pendingSession.SessionToken,
pendingSession.BrowserSessionKey,
); err != nil {
response.ErrorFrom(c, err)
return
}
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
user, err = h.userService.GetByID(c.Request.Context(), session.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
}
// Delete the login session (only after all checks pass)
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
if session.PendingOAuthBind == nil {
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
}
h.respondWithTokenPair(c, user)
}
......@@ -290,8 +400,14 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
return
}
identities, err := h.userService.GetProfileIdentitySummaries(c.Request.Context(), subject.UserID, user)
if err != nil {
response.ErrorFrom(c, err)
return
}
type UserResponse struct {
*dto.User
userProfileResponse
RunMode string `json:"run_mode"`
}
......@@ -300,7 +416,10 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
runMode = h.cfg.RunMode
}
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
response.Success(c, UserResponse{
userProfileResponse: userProfileResponseFromService(user, identities),
RunMode: runMode,
})
}
// ValidatePromoCodeRequest 验证优惠码请求
......@@ -578,6 +697,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
// 不影响登出流程
}
}
h.consumePendingOAuthSessionOnLogout(c)
clearOAuthLogoutCookies(c)
response.Success(c, LogoutResponse{
Message: "Logged out successfully",
......@@ -598,7 +719,7 @@ func (h *AuthHandler) RevokeAllSessions(c *gin.Context) {
return
}
if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil {
if err := h.authService.RevokeAllUserTokens(c.Request.Context(), subject.UserID); err != nil {
slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err)
response.InternalError(c, "Failed to revoke sessions")
return
......
......@@ -2,6 +2,8 @@ package handler
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
......@@ -13,10 +15,13 @@ import (
"time"
"unicode/utf8"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
......@@ -25,17 +30,24 @@ import (
)
const (
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo"
oauthBindAccessTokenCookiePath = "/api/v1/auth/oauth"
linuxDoOAuthStateCookieName = "linuxdo_oauth_state"
linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier"
linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect"
linuxDoOAuthIntentCookieName = "linuxdo_oauth_intent"
linuxDoOAuthBindUserCookieName = "linuxdo_oauth_bind_user"
oauthBindAccessTokenCookieName = "oauth_bind_access_token"
linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes
linuxDoOAuthDefaultRedirectTo = "/dashboard"
linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback"
linuxDoOAuthMaxRedirectLen = 2048
linuxDoOAuthMaxFragmentValueLen = 512
linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-")
oauthIntentLogin = "login"
oauthIntentBindCurrentUser = "bind_current_user"
)
type linuxDoTokenResponse struct {
......@@ -87,9 +99,29 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
browserSessionKey, err := generateOAuthPendingBrowserSession()
if err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
return
}
secureCookie := isRequestHTTPS(c)
setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie)
intent := normalizeOAuthIntent(c.Query("intent"))
setCookie(c, linuxDoOAuthIntentCookieName, encodeCookieValue(intent), linuxDoOAuthCookieMaxAgeSec, secureCookie)
setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
clearOAuthPendingSessionCookie(c, secureCookie)
if intent == oauthIntentBindCurrentUser {
bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
if err != nil {
response.ErrorFrom(c, err)
return
}
setCookie(c, linuxDoOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), linuxDoOAuthCookieMaxAgeSec, secureCookie)
} else {
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}
codeChallenge := ""
if cfg.UsePKCE {
......@@ -148,6 +180,8 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
}()
expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName)
......@@ -161,6 +195,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
if redirectTo == "" {
redirectTo = linuxDoOAuthDefaultRedirectTo
}
browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
if strings.TrimSpace(browserSessionKey) == "" {
redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "")
return
}
intent, _ := readCookieDecoded(c, linuxDoOAuthIntentCookieName)
intent = normalizeOAuthIntent(intent)
codeVerifier := ""
if cfg.UsePKCE {
......@@ -198,52 +239,204 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) {
return
}
email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp)
if err != nil {
log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err)
redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "")
return
}
compatEmail := strings.TrimSpace(email)
// 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。
// 统一使用基于 subject 的稳定合成邮箱来做账号绑定。
if subject != "" {
email = linuxDoSyntheticEmail(subject)
}
identityKey := service.PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: subject,
}
upstreamClaims := map[string]any{
"email": email,
"username": username,
"subject": subject,
"suggested_display_name": displayName,
"suggested_avatar_url": avatarURL,
}
if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) {
upstreamClaims["compat_email"] = compatEmail
}
if intent == oauthIntentBindCurrentUser {
targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName)
if err != nil {
redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth bind target", "")
return
}
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentBindCurrentUser,
Identity: identityKey,
TargetUserID: &targetUserID,
ResolvedEmail: email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth bind", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
return
}
// 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "")
existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey)
if err != nil {
if errors.Is(err, service.ErrOAuthInvitationRequired) {
pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username)
if tokenErr != nil {
redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "")
return
}
fragment := url.Values{}
fragment.Set("error", "invitation_required")
fragment.Set("pending_oauth_token", pendingToken)
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if existingIdentityUser != nil {
if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identityKey,
TargetUserID: &existingIdentityUser.ID,
ResolvedEmail: existingIdentityUser.Email,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: map[string]any{
"redirect": redirectTo,
},
}); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
// 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。
redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err))
redirectToFrontendCallback(c, frontendCallback)
return
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
redirectWithFragment(c, frontendCallback, fragment)
compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail)
if err != nil {
redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
return
}
if err := h.createLinuxDoOAuthChoicePendingSession(
c,
identityKey,
email,
email,
redirectTo,
browserSessionKey,
upstreamClaims,
compatEmail,
compatEmailUser,
h.isForceEmailOnThirdPartySignup(c.Request.Context()),
); err != nil {
redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "")
return
}
redirectToFrontendCallback(c, frontendCallback)
}
func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
email = strings.TrimSpace(strings.ToLower(email))
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
return nil, nil
}
userEntity, err := client.User.Query().
Where(userNormalizedEmailPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
}
switch len(userEntity) {
case 0:
return nil, nil
case 1:
return userEntity[0], nil
default:
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
}
}
func (h *AuthHandler) createLinuxDoOAuthChoicePendingSession(
c *gin.Context,
identity service.PendingAuthIdentityKey,
suggestedEmail string,
resolvedEmail string,
redirectTo string,
browserSessionKey string,
upstreamClaims map[string]any,
compatEmail string,
compatEmailUser *dbent.User,
forceEmailOnSignup bool,
) error {
suggestionEmail := strings.TrimSpace(suggestedEmail)
canonicalEmail := strings.TrimSpace(resolvedEmail)
if suggestionEmail == "" {
suggestionEmail = canonicalEmail
}
completionResponse := map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"redirect": strings.TrimSpace(redirectTo),
"email": suggestionEmail,
"resolved_email": canonicalEmail,
"existing_account_email": "",
"existing_account_bindable": false,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
"choice_reason": "third_party_signup",
}
if strings.TrimSpace(compatEmail) != "" {
completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
}
resolvedChoiceEmail := suggestionEmail
if compatEmailUser != nil {
completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
completionResponse["existing_account_bindable"] = true
completionResponse["choice_reason"] = "compat_email_match"
resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
}
if forceEmailOnSignup && compatEmailUser == nil {
completionResponse["choice_reason"] = "force_email_on_signup"
}
var targetUserID *int64
if compatEmailUser != nil && compatEmailUser.ID > 0 {
targetUserID = &compatEmailUser.ID
}
return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
Intent: oauthIntentLogin,
Identity: identity,
TargetUserID: targetUserID,
ResolvedEmail: resolvedChoiceEmail,
RedirectTo: redirectTo,
BrowserSessionKey: browserSessionKey,
UpstreamIdentityClaims: upstreamClaims,
CompletionResponse: completionResponse,
})
}
type completeLinuxDoOAuthRequest struct {
PendingOAuthToken string `json:"pending_oauth_token" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
InvitationCode string `json:"invitation_code" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
// CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating
......@@ -256,17 +449,87 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
return
}
email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken)
secureCookie := isRequestHTTPS(c)
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
return
}
pendingSvc, err := h.pendingIdentityService()
if err != nil {
response.ErrorFrom(c, err)
return
}
session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"})
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
response.ErrorFrom(c, err)
return
} else if handled {
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
return
} else {
session = updatedSession
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
email := strings.TrimSpace(session.ResolvedEmail)
username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
if email == "" || username == "" {
response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
AdoptDisplayName: req.AdoptDisplayName,
AdoptAvatar: req.AdoptAvatar,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
......@@ -303,7 +566,7 @@ func linuxDoExchangeCode(
form.Set("client_id", cfg.ClientID)
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
if cfg.UsePKCE {
if strings.TrimSpace(codeVerifier) != "" {
form.Set("code_verifier", codeVerifier)
}
......@@ -353,11 +616,11 @@ func linuxDoFetchUserInfo(
ctx context.Context,
cfg config.LinuxDoConnectConfig,
token *linuxDoTokenResponse,
) (email string, username string, subject string, err error) {
) (email string, username string, subject string, displayName string, avatarURL string, err error) {
client := req.C().SetTimeout(30 * time.Second)
authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken)
if err != nil {
return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err)
}
resp, err := client.R().
......@@ -366,16 +629,16 @@ func linuxDoFetchUserInfo(
SetHeader("Authorization", authorization).
Get(cfg.UserInfoURL)
if err != nil {
return "", "", "", fmt.Errorf("request userinfo: %w", err)
return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err)
}
if !resp.IsSuccessState() {
return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode)
}
return linuxDoParseUserInfo(resp.String(), cfg)
}
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) {
func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) {
email = firstNonEmpty(
getGJSON(body, cfg.UserInfoEmailPath),
getGJSON(body, "email"),
......@@ -400,12 +663,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
getGJSON(body, "user.id"),
)
displayName = firstNonEmpty(
getGJSON(body, "name"),
getGJSON(body, "nickname"),
getGJSON(body, "display_name"),
getGJSON(body, "user.name"),
getGJSON(body, "user.username"),
username,
)
avatarURL = firstNonEmpty(
getGJSON(body, "avatar_url"),
getGJSON(body, "avatar"),
getGJSON(body, "picture"),
getGJSON(body, "profile_image_url"),
getGJSON(body, "user.avatar"),
getGJSON(body, "user.avatar_url"),
)
subject = strings.TrimSpace(subject)
if subject == "" {
return "", "", "", errors.New("userinfo missing id field")
return "", "", "", "", "", errors.New("userinfo missing id field")
}
if !isSafeLinuxDoSubject(subject) {
return "", "", "", errors.New("userinfo returned invalid id field")
return "", "", "", "", "", errors.New("userinfo returned invalid id field")
}
email = strings.TrimSpace(email)
......@@ -418,8 +698,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s
if username == "" {
username = "linuxdo_" + subject
}
displayName = strings.TrimSpace(displayName)
if displayName == "" {
displayName = username
}
avatarURL = strings.TrimSpace(avatarURL)
return email, username, subject, nil
return email, username, subject, displayName, avatarURL, nil
}
func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) {
......@@ -436,7 +721,7 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod
q.Set("scope", cfg.Scopes)
}
q.Set("state", state)
if cfg.UsePKCE {
if strings.TrimSpace(codeChallenge) != "" {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
}
......@@ -670,6 +955,30 @@ func clearCookie(c *gin.Context, name string, secure bool) {
})
}
func clearOAuthBindAccessTokenCookie(c *gin.Context, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthBindAccessTokenCookieName,
Value: "",
Path: oauthBindAccessTokenCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func setOAuthBindAccessTokenCookie(c *gin.Context, token string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthBindAccessTokenCookieName,
Value: url.QueryEscape(strings.TrimSpace(token)),
Path: oauthBindAccessTokenCookiePath,
MaxAge: linuxDoOAuthCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func truncateFragmentValue(value string) string {
value = strings.TrimSpace(value)
if value == "" {
......@@ -728,3 +1037,127 @@ func linuxDoSyntheticEmail(subject string) string {
}
return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain
}
func normalizeOAuthIntent(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", oauthIntentLogin:
return oauthIntentLogin
case "bind", oauthIntentBindCurrentUser:
return oauthIntentBindCurrentUser
default:
return oauthIntentLogin
}
}
func (h *AuthHandler) buildOAuthBindUserCookieFromContext(c *gin.Context) (string, error) {
userID, err := h.resolveOAuthBindTargetUserID(c)
if err != nil || userID == nil || *userID <= 0 {
return "", infraerrors.Unauthorized("UNAUTHORIZED", "authentication required")
}
return buildOAuthBindUserCookieValue(*userID, h.oauthBindCookieSecret())
}
func (h *AuthHandler) PrepareOAuthBindAccessTokenCookie(c *gin.Context) {
const bearerPrefix = "Bearer "
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if !strings.HasPrefix(strings.ToLower(authHeader), strings.ToLower(bearerPrefix)) {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
token := strings.TrimSpace(authHeader[len(bearerPrefix):])
if token == "" {
response.ErrorFrom(c, infraerrors.Unauthorized("UNAUTHORIZED", "authentication required"))
return
}
setOAuthBindAccessTokenCookie(c, token, isRequestHTTPS(c))
c.Status(http.StatusNoContent)
c.Writer.WriteHeaderNow()
}
func (h *AuthHandler) resolveOAuthBindTargetUserID(c *gin.Context) (*int64, error) {
if subject, ok := servermiddleware.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 {
return &subject.UserID, nil
}
if h == nil || h.authService == nil || h.userService == nil {
return nil, service.ErrInvalidToken
}
ck, err := c.Request.Cookie(oauthBindAccessTokenCookieName)
clearOAuthBindAccessTokenCookie(c, isRequestHTTPS(c))
if err != nil {
return nil, err
}
tokenString, err := url.QueryUnescape(strings.TrimSpace(ck.Value))
if err != nil {
return nil, err
}
if tokenString == "" {
return nil, service.ErrInvalidToken
}
claims, err := h.authService.ValidateToken(tokenString)
if err != nil {
return nil, err
}
user, err := h.userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
return nil, err
}
if user == nil || !user.IsActive() || claims.TokenVersion != user.TokenVersion {
return nil, service.ErrInvalidToken
}
return &user.ID, nil
}
func (h *AuthHandler) readOAuthBindUserIDFromCookie(c *gin.Context, cookieName string) (int64, error) {
value, err := readCookieDecoded(c, cookieName)
if err != nil {
return 0, err
}
return parseOAuthBindUserCookieValue(value, h.oauthBindCookieSecret())
}
func (h *AuthHandler) oauthBindCookieSecret() string {
if h == nil || h.cfg == nil {
return ""
}
return strings.TrimSpace(h.cfg.JWT.Secret)
}
func buildOAuthBindUserCookieValue(userID int64, secret string) (string, error) {
secret = strings.TrimSpace(secret)
if userID <= 0 || secret == "" {
return "", errors.New("invalid oauth bind cookie input")
}
payload := strconv.FormatInt(userID, 10)
mac := hmac.New(sha256.New, []byte(secret))
_, _ = mac.Write([]byte(payload))
signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return payload + "." + signature, nil
}
func parseOAuthBindUserCookieValue(value string, secret string) (int64, error) {
secret = strings.TrimSpace(secret)
if secret == "" {
return 0, errors.New("missing oauth bind cookie secret")
}
payload, signature, ok := strings.Cut(strings.TrimSpace(value), ".")
if !ok || payload == "" || signature == "" {
return 0, errors.New("invalid oauth bind cookie")
}
mac := hmac.New(sha256.New, []byte(secret))
_, _ = mac.Write([]byte(payload))
expectedSignature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
if !hmac.Equal([]byte(signature), []byte(expectedSignature)) {
return 0, errors.New("invalid oauth bind cookie signature")
}
userID, err := strconv.ParseInt(payload, 10, 64)
if err != nil || userID <= 0 {
return 0, errors.New("invalid oauth bind cookie user")
}
return userID, nil
}
package handler
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/config"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
......@@ -41,11 +55,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg)
email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "alice", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
require.Equal(t, "Alice", displayName)
require.Equal(t, "https://cdn.example/avatar.png", avatarURL)
}
func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
......@@ -53,11 +69,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg)
require.NoError(t, err)
require.Equal(t, "123", subject)
require.Equal(t, "linuxdo_123", username)
require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email)
require.Equal(t, "linuxdo_123", displayName)
require.Equal(t, "", avatarURL)
}
func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
......@@ -65,11 +83,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) {
UserInfoURL: "https://connect.linux.do/api/user",
}
_, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
_, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg)
require.Error(t, err)
tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1)
_, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
_, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg)
require.Error(t, err)
}
......@@ -106,3 +124,906 @@ func TestSingleLineStripsWhitespace(t *testing.T) {
require.Equal(t, "hello world", singleLine("hello\r\nworld"))
require.Equal(t, "", singleLine("\n\t\r"))
}
func TestLinuxDoOAuthBindStartRedirectsAndSetsBindCookies(t *testing.T) {
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/bind/start?intent=bind_current_user&redirect=/settings/connections", nil)
c.Request = req
c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 42})
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "connect.linux.do/oauth/authorize")
require.Contains(t, location, "client_id=linuxdo-client")
require.Contains(t, location, "code_challenge=")
cookies := recorder.Result().Cookies()
require.NotNil(t, findCookie(cookies, linuxDoOAuthStateCookieName))
require.NotNil(t, findCookie(cookies, linuxDoOAuthRedirectCookie))
require.NotNil(t, findCookie(cookies, linuxDoOAuthVerifierCookie))
require.NotNil(t, findCookie(cookies, oauthPendingBrowserCookieName))
intentCookie := findCookie(cookies, linuxDoOAuthIntentCookieName)
require.NotNil(t, intentCookie)
require.Equal(t, oauthIntentBindCurrentUser, decodeCookieValueForTest(t, intentCookie.Value))
bindCookie := findCookie(cookies, linuxDoOAuthBindUserCookieName)
require.NotNil(t, bindCookie)
userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
require.NoError(t, err)
require.Equal(t, int64(42), userID)
}
func TestLinuxDoOAuthStartOmitsPKCEWhenDisabled(t *testing.T) {
handler := newLinuxDoOAuthTestHandler(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?redirect=/dashboard", nil)
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.NotContains(t, recorder.Header().Get("Location"), "code_challenge=")
require.Nil(t, findCookie(recorder.Result().Cookies(), linuxDoOAuthVerifierCookie))
}
func TestLinuxDoOAuthCallbackAllowsMissingVerifierWhenPKCEDisabled(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
require.NoError(t, r.ParseForm())
require.Empty(t, r.PostForm.Get("code_verifier"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"compat-subject","username":"linuxdo_user","name":"LinuxDo Display"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: false,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=linuxdo-code&state=state-123", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
require.NotNil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
}
func TestLinuxDoOAuthBindStartAcceptsAccessTokenCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: "https://connect.linux.do/oauth/authorize",
TokenURL: "https://connect.linux.do/oauth/token",
UserInfoURL: "https://connect.linux.do/api/user",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
user, err := client.User.Create().
SetEmail("bind-cookie@example.com").
SetUsername("bind-cookie-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(context.Background())
require.NoError(t, err)
token, err := handler.authService.GenerateToken(&service.User{
ID: user.ID,
Email: user.Email,
Username: user.Username,
PasswordHash: user.PasswordHash,
Role: user.Role,
Status: user.Status,
})
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/start?intent=bind_current_user&redirect=/settings/connections", nil)
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: token, Path: oauthBindAccessTokenCookiePath})
c.Request = req
handler.LinuxDoOAuthStart(c)
require.Equal(t, http.StatusFound, recorder.Code)
bindCookie := findCookie(recorder.Result().Cookies(), linuxDoOAuthBindUserCookieName)
require.NotNil(t, bindCookie)
userID, err := parseOAuthBindUserCookieValue(decodeCookieValueForTest(t, bindCookie.Value), "test-secret")
require.NoError(t, err)
require.Equal(t, user.ID, userID)
accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
require.NotNil(t, accessTokenCookie)
require.Equal(t, -1, accessTokenCookie.MaxAge)
}
func TestPrepareOAuthBindAccessTokenCookieSetsHttpOnlyCookie(t *testing.T) {
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/bind-token", nil)
req.Header.Set("Authorization", "Bearer access-token-value")
c.Request = req
handler.PrepareOAuthBindAccessTokenCookie(c)
require.Equal(t, http.StatusNoContent, recorder.Code)
accessTokenCookie := findCookie(recorder.Result().Cookies(), oauthBindAccessTokenCookieName)
require.NotNil(t, accessTokenCookie)
require.Equal(t, oauthBindAccessTokenCookiePath, accessTokenCookie.Path)
require.Equal(t, linuxDoOAuthCookieMaxAgeSec, accessTokenCookie.MaxAge)
require.True(t, accessTokenCookie.HttpOnly)
require.Equal(t, url.QueryEscape("access-token-value"), accessTokenCookie.Value)
}
func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"321","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(linuxDoSyntheticEmail("321")).
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("321").
SetMetadata(map[string]any{"username": "legacy-user"}).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-123&state=state-123", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-123"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-123"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, linuxDoSyntheticEmail("321"), session.ResolvedEmail)
require.Equal(t, "LinuxDo Display", session.UpstreamIdentityClaims["suggested_display_name"])
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
_, hasRefreshToken := completion["refresh_token"]
require.False(t, hasRefreshToken)
require.Nil(t, completion["error"])
}
func TestLinuxDoOAuthCallbackRejectsDisabledExistingIdentityUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_disabled","name":"LinuxDo Disabled"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(linuxDoSyntheticEmail("654")).
SetUsername("disabled-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusDisabled).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingUser.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("654").
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-disabled&state=state-disabled", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-disabled"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-disabled"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName))
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "session_error", "USER_NOT_ACTIVE")
count, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, count)
}
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail(" Legacy@Example.com ").
SetUsername("legacy-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, existingUser.ID, *session.TargetUserID)
require.Equal(t, strings.TrimSpace(existingUser.Email), session.ResolvedEmail)
require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"])
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["email"])
require.Equal(t, strings.TrimSpace(existingUser.Email), completion["existing_account_email"])
require.Equal(t, true, completion["existing_account_bindable"])
require.Equal(t, "compat_email_match", completion["choice_reason"])
_, hasAccessToken := completion["access_token"]
require.False(t, hasAccessToken)
}
func TestLinuxDoOAuthCallbackCreatesChoicePendingSessionWhenSignupRequiresInvite(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"654","username":"linuxdo_invite","name":"Need Invite","avatar_url":"https://cdn.example/invite.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, true, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-456&state=state-456", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-456"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-456"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-456"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
ctx := context.Background()
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentLogin, session.Intent)
require.Nil(t, session.TargetUserID)
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, oauthPendingChoiceStep, completion["step"])
require.Equal(t, "/dashboard", completion["redirect"])
require.Equal(t, "third_party_signup", completion["choice_reason"])
}
func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCurrentUser(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`))
case "/userinfo":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"999","username":"bind_user","name":"Bind Display","avatar_url":"https://cdn.example/bind.png"}`))
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{
Enabled: true,
ClientID: "linuxdo-client",
ClientSecret: "linuxdo-secret",
AuthorizeURL: upstream.URL + "/authorize",
TokenURL: upstream.URL + "/token",
UserInfoURL: upstream.URL + "/userinfo",
Scopes: "read",
RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback",
FrontendRedirectURL: "/auth/linuxdo/callback",
TokenAuthMethod: "client_secret_post",
UsePKCE: true,
})
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
currentUser, err := client.User.Create().
SetEmail("current@example.com").
SetUsername("current-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-bind&state=state-bind", nil)
req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-bind"))
req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/settings/connections"))
req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-bind"))
req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentBindCurrentUser))
req.AddCookie(encodedCookie(linuxDoOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret")))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-bind"))
c.Request = req
handler.LinuxDoOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location"))
sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)
require.NotNil(t, sessionCookie)
session, err := client.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))).
Only(ctx)
require.NoError(t, err)
require.Equal(t, oauthIntentBindCurrentUser, session.Intent)
require.NotNil(t, session.TargetUserID)
require.Equal(t, currentUser.ID, *session.TargetUserID)
require.Equal(t, linuxDoSyntheticEmail("999"), session.ResolvedEmail)
completion, ok := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any)
require.True(t, ok)
require.Equal(t, "/settings/connections", completion["redirect"])
require.Empty(t, completion["access_token"])
require.Equal(t, "Bind Display", session.UpstreamIdentityClaims["suggested_display_name"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Equal(t, 1, userCount)
}
func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-1").
SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Display",
"suggested_avatar_url": "https://cdn.example/linuxdo.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptAvatar: true,
})
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "LinuxDo Display", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-1"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"])
require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"])
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.True(t, decision.AdoptDisplayName)
require.True(t, decision.AdoptAvatar)
consumed, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationRejectsAdoptExistingUserSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingUser, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-invalid-session").
SetIntent("adopt_existing_user_by_email").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-invalid-subject-1").
SetTargetUserID(existingUser.ID).
SetResolvedEmail(existingUser.Email).
SetBrowserSessionKey("linuxdo-invalid-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": "bind_login_required",
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-invalid-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusBadRequest, recorder.Code)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationReturnsPendingSessionWhenChoiceStillRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-choice-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-choice-subject-1").
SetResolvedEmail("linuxdo-choice-subject-1@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-choice-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetLocalFlowState(map[string]any{
oauthCompletionResponseKey: map[string]any{
"step": oauthPendingChoiceStep,
"redirect": "/dashboard",
"email": "fresh@example.com",
"resolved_email": "fresh@example.com",
"force_email_on_signup": true,
},
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-choice-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.Equal(t, "pending_session", responseData["auth_result"])
require.Equal(t, oauthPendingChoiceStep, responseData["step"])
require.Equal(t, true, responseData["force_email_on_signup"])
require.Empty(t, responseData["access_token"])
userCount, err := client.User.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func TestCompleteLinuxDoOAuthRegistrationBindsIdentityWithoutAdoptionFlags(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-no-adoption-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-subject-no-adoption").
SetResolvedEmail("linuxdo-subject-no-adoption@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-browser-no-adoption").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
"suggested_display_name": "LinuxDo Legacy",
"suggested_avatar_url": "https://cdn.example/linuxdo-legacy.png",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser-no-adoption")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusOK, recorder.Code)
responseData := decodeJSONBody(t, recorder)
require.NotEmpty(t, responseData["access_token"])
require.NotEmpty(t, responseData["refresh_token"])
userEntity, err := client.User.Query().
Where(dbuser.EmailEQ(session.ResolvedEmail)).
Only(ctx)
require.NoError(t, err)
require.Equal(t, "linuxdo_user", userEntity.Username)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("linuxdo"),
authidentity.ProviderKeyEQ("linuxdo"),
authidentity.ProviderSubjectEQ("linuxdo-subject-no-adoption"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, userEntity.ID, identity.UserID)
decision, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, decision.IdentityID)
require.Equal(t, identity.ID, *decision.IdentityID)
require.False(t, decision.AdoptDisplayName)
require.False(t, decision.AdoptAvatar)
}
func TestCompleteLinuxDoOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUserCreation(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
existingOwner, err := client.User.Create().
SetEmail("owner@example.com").
SetUsername("owner-user").
SetPasswordHash("hash").
SetRole(service.RoleUser).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
_, err = client.AuthIdentity.Create().
SetUserID(existingOwner.ID).
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
Save(ctx)
require.NoError(t, err)
session, err := client.PendingAuthSession.Create().
SetSessionToken("linuxdo-complete-conflict-session").
SetIntent("login").
SetProviderType("linuxdo").
SetProviderKey("linuxdo").
SetProviderSubject("linuxdo-conflict-subject").
SetResolvedEmail("linuxdo-conflict-subject@linuxdo-connect.invalid").
SetBrowserSessionKey("linuxdo-conflict-browser").
SetUpstreamIdentityClaims(map[string]any{
"username": "linuxdo_user",
}).
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
body := bytes.NewBufferString(`{"invitation_code":"invite-1"}`)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body)
req.Header.Set("Content-Type", "application/json")
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-conflict-browser")})
c.Request = req
handler.CompleteLinuxDoOAuthRegistration(c)
require.Equal(t, http.StatusConflict, recorder.Code)
payload := decodeJSONBody(t, recorder)
require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", payload["reason"])
userCount, err := client.User.Query().
Where(dbuser.EmailEQ("linuxdo-conflict-subject@linuxdo-connect.invalid")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
storedSession, err := client.PendingAuthSession.Get(ctx, session.ID)
require.NoError(t, err)
require.Nil(t, storedSession.ConsumedAt)
}
func newLinuxDoOAuthTestHandler(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) *AuthHandler {
t.Helper()
handler, _ := newLinuxDoOAuthHandlerAndClient(t, invitationEnabled, oauthCfg)
return handler
}
func newLinuxDoOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg config.LinuxDoConnectConfig) (*AuthHandler, *dbent.Client) {
t.Helper()
handler, client := newOAuthPendingFlowTestHandler(t, invitationEnabled)
handler.settingSvc = nil
handler.cfg = &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
AccessTokenExpireMinutes: 60,
RefreshTokenExpireDays: 7,
},
LinuxDo: oauthCfg,
}
return handler, client
}
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestLogoutClearsOAuthStateCookiesAndConsumesPendingSession(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
ctx := context.Background()
session, err := client.PendingAuthSession.Create().
SetSessionToken("logout-pending-session-token").
SetIntent("login").
SetProviderType("oidc").
SetProviderKey("https://issuer.example").
SetProviderSubject("logout-subject-123").
SetBrowserSessionKey("logout-browser-session-key").
SetResolvedEmail("logout@example.com").
SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)).
Save(ctx)
require.NoError(t, err)
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/logout", nil)
req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)})
req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("logout-browser-session-key")})
req.AddCookie(&http.Cookie{Name: oauthBindAccessTokenCookieName, Value: "bind-access-token"})
req.AddCookie(&http.Cookie{Name: linuxDoOAuthStateCookieName, Value: encodeCookieValue("linuxdo-state")})
req.AddCookie(&http.Cookie{Name: oidcOAuthStateCookieName, Value: encodeCookieValue("oidc-state")})
req.AddCookie(&http.Cookie{Name: wechatOAuthStateCookieName, Value: encodeCookieValue("wechat-state")})
req.AddCookie(&http.Cookie{Name: wechatPaymentOAuthStateName, Value: encodeCookieValue("wechat-payment-state")})
ginCtx.Request = req
handler.Logout(ginCtx)
require.Equal(t, http.StatusOK, recorder.Code)
cookies := recorder.Result().Cookies()
for _, name := range []string{
oauthPendingSessionCookieName,
oauthPendingBrowserCookieName,
oauthBindAccessTokenCookieName,
linuxDoOAuthStateCookieName,
oidcOAuthStateCookieName,
wechatOAuthStateCookieName,
wechatPaymentOAuthStateName,
} {
cookie := findCookie(cookies, name)
require.NotNil(t, cookie, name)
require.Equal(t, -1, cookie.MaxAge, name)
require.True(t, cookie.HttpOnly, name)
}
storedSession, err := client.PendingAuthSession.Query().
Where(pendingauthsession.IDEQ(session.ID)).
Only(ctx)
require.NoError(t, err)
require.NotNil(t, storedSession.ConsumedAt)
}
package handler
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/authidentitychannel"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/predicate"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
entsql "entgo.io/ent/dialect/sql"
"github.com/gin-gonic/gin"
)
const (
oauthPendingBrowserCookiePath = "/api/v1/auth/oauth"
oauthPendingBrowserCookieName = "oauth_pending_browser_session"
oauthPendingSessionCookiePath = "/api/v1/auth/oauth"
oauthPendingSessionCookieName = "oauth_pending_session"
oauthPendingCookieMaxAgeSec = 10 * 60
oauthPendingChoiceStep = "choose_account_action_required"
oauthCompletionResponseKey = "completion_response"
)
var pendingOAuthCreateAccountPreCommitHook func(context.Context, *dbent.PendingAuthSession) error
type oauthPendingSessionPayload struct {
Intent string
Identity service.PendingAuthIdentityKey
TargetUserID *int64
ResolvedEmail string
RedirectTo string
BrowserSessionKey string
UpstreamIdentityClaims map[string]any
CompletionResponse map[string]any
}
type oauthAdoptionDecisionRequest struct {
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type bindPendingOAuthLoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type createPendingOAuthAccountRequest struct {
Email string `json:"email" binding:"required,email"`
VerifyCode string `json:"verify_code,omitempty"`
Password string `json:"password" binding:"required,min=6"`
InvitationCode string `json:"invitation_code,omitempty"`
AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
}
type sendPendingOAuthVerifyCodeRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token,omitempty"`
PendingAuthToken string `json:"pending_auth_token,omitempty"`
PendingOAuthToken string `json:"pending_oauth_token,omitempty"`
}
func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest {
return oauthAdoptionDecisionRequest{
AdoptDisplayName: r.AdoptDisplayName,
AdoptAvatar: r.AdoptAvatar,
}
}
func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) {
if h == nil || h.authService == nil || h.authService.EntClient() == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil
}
func generateOAuthPendingBrowserSession() (string, error) {
return oauth.GenerateState()
}
func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthPendingBrowserCookieName,
Value: encodeCookieValue(sessionKey),
Path: oauthPendingBrowserCookiePath,
MaxAge: oauthPendingCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthPendingBrowserCookieName,
Value: "",
Path: oauthPendingBrowserCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) {
return readCookieDecoded(c, oauthPendingBrowserCookieName)
}
func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthPendingSessionCookieName,
Value: encodeCookieValue(sessionToken),
Path: oauthPendingSessionCookiePath,
MaxAge: oauthPendingCookieMaxAgeSec,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) {
http.SetCookie(c.Writer, &http.Cookie{
Name: oauthPendingSessionCookieName,
Value: "",
Path: oauthPendingSessionCookiePath,
MaxAge: -1,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
})
}
func readOAuthPendingSessionCookie(c *gin.Context) (string, error) {
return readCookieDecoded(c, oauthPendingSessionCookieName)
}
func redirectToFrontendCallback(c *gin.Context, frontendCallback string) {
u, err := url.Parse(frontendCallback)
if err != nil {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") {
c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo)
return
}
u.Fragment = ""
c.Header("Cache-Control", "no-store")
c.Header("Pragma", "no-cache")
c.Redirect(http.StatusFound, u.String())
}
func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error {
svc, err := h.pendingIdentityService()
if err != nil {
return err
}
session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{
Intent: strings.TrimSpace(payload.Intent),
Identity: payload.Identity,
TargetUserID: payload.TargetUserID,
ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail),
RedirectTo: strings.TrimSpace(payload.RedirectTo),
BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey),
UpstreamIdentityClaims: payload.UpstreamIdentityClaims,
LocalFlowState: map[string]any{
oauthCompletionResponseKey: payload.CompletionResponse,
},
})
if err != nil {
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c))
return nil
}
func readCompletionResponse(session map[string]any) (map[string]any, bool) {
if len(session) == 0 {
return nil, false
}
value, ok := session[oauthCompletionResponseKey]
if !ok {
return nil, false
}
result, ok := value.(map[string]any)
if !ok {
return nil, false
}
return result, true
}
func clonePendingMap(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any {
payload, _ := readCompletionResponse(session.LocalFlowState)
merged := clonePendingMap(payload)
if strings.TrimSpace(session.RedirectTo) != "" {
if _, exists := merged["redirect"]; !exists {
merged["redirect"] = session.RedirectTo
}
}
for key, value := range overrides {
if value == nil {
delete(merged, key)
continue
}
merged[key] = value
}
applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims)
return merged
}
func pendingSessionStringValue(values map[string]any, key string) string {
if len(values) == 0 {
return ""
}
raw, ok := values[key]
if !ok {
return ""
}
value, ok := raw.(string)
if !ok {
return ""
}
return strings.TrimSpace(value)
}
func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(session.Intent), oauthIntentLogin) {
return false
}
if session.TargetUserID == nil || *session.TargetUserID <= 0 {
return false
}
if pendingSessionWantsInvitation(payload) {
return false
}
return strings.TrimSpace(pendingSessionStringValue(payload, "step")) == ""
}
func ensurePendingOAuthCompleteRegistrationSession(session *dbent.PendingAuthSession) error {
if session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if strings.TrimSpace(session.Intent) != oauthIntentLogin {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload, _ := readCompletionResponse(session.LocalFlowState)
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
return nil
}
func buildLegacyCompleteRegistrationPendingResponse(
session *dbent.PendingAuthSession,
forceEmailOnSignup bool,
emailVerificationRequired bool,
) map[string]any {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"create_account_allowed": true,
"force_email_on_signup": forceEmailOnSignup,
}))
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
if _, exists := completionResponse["email"]; !exists {
completionResponse["email"] = email
}
if _, exists := completionResponse["resolved_email"]; !exists {
completionResponse["resolved_email"] = email
}
}
if _, exists := completionResponse["choice_reason"]; !exists {
switch {
case forceEmailOnSignup:
completionResponse["choice_reason"] = "force_email_on_signup"
case emailVerificationRequired:
completionResponse["choice_reason"] = "email_verification_required"
default:
completionResponse["choice_reason"] = "third_party_signup"
}
}
return completionResponse
}
func (h *AuthHandler) legacyCompleteRegistrationSessionStatus(
c *gin.Context,
session *dbent.PendingAuthSession,
) (*dbent.PendingAuthSession, bool, error) {
if session == nil {
return nil, false, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
payload := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
if step := pendingSessionStringValue(payload, "step"); step != "" {
return session, true, nil
}
emailVerificationRequired := h != nil && h.authService != nil && h.authService.IsEmailVerifyEnabled(c.Request.Context())
forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
if !emailVerificationRequired && !forceEmailOnSignup {
return session, false, nil
}
client := h.entClient()
if client == nil {
return nil, false, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
updatedSession, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
strings.TrimSpace(session.ResolvedEmail),
nil,
buildLegacyCompleteRegistrationPendingResponse(session, forceEmailOnSignup, emailVerificationRequired),
)
if err != nil {
return nil, false, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return updatedSession, true, nil
}
func (r oauthAdoptionDecisionRequest) hasDecision() bool {
return r.AdoptDisplayName != nil || r.AdoptAvatar != nil
}
func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) {
var req oauthAdoptionDecisionRequest
if c == nil || c.Request == nil || c.Request.Body == nil {
return req, nil
}
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, io.EOF) {
return req, nil
}
return req, err
}
return req, nil
}
func cloneOAuthMetadata(values map[string]any) map[string]any {
if len(values) == 0 {
return map[string]any{}
}
cloned := make(map[string]any, len(values))
for key, value := range values {
cloned[key] = value
}
return cloned
}
func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any {
merged := cloneOAuthMetadata(base)
for key, value := range overlay {
merged[key] = value
}
return merged
}
func normalizeAdoptedOAuthDisplayName(value string) string {
value = strings.TrimSpace(value)
if len([]rune(value)) > 100 {
value = string([]rune(value)[:100])
}
return value
}
func (h *AuthHandler) entClient() *dbent.Client {
if h == nil || h.authService == nil {
return nil
}
return h.authService.EntClient()
}
func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool {
if h == nil || h.settingSvc == nil {
return false
}
defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx)
if err != nil || defaults == nil {
return false
}
return defaults.ForceEmailOnThirdPartySignup
}
func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
record, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
return findActiveUserByID(ctx, client, record.UserID)
}
func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") }
func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") }
func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") }
func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") }
func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "linuxdo")
}
func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") }
func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "wechat")
}
func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) {
h.createPendingOAuthAccount(c, "")
}
// SendPendingOAuthVerifyCode sends a verification code for a browser-bound
// pending OAuth account-creation flow.
// POST /api/v1/auth/oauth/pending/send-verify-code
func (h *AuthHandler) SendPendingOAuthVerifyCode(c *gin.Context) {
var req sendPendingOAuthVerifyCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
response.ErrorFrom(c, err)
return
}
_, session, _, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
email := strings.TrimSpace(strings.ToLower(req.Email))
if existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email); err == nil && existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
} else if err != nil && !errors.Is(err, service.ErrUserNotFound) {
response.ErrorFrom(c, err)
return
}
result, err := h.authService.SendPendingOAuthVerifyCode(c.Request.Context(), req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, SendVerifyCodeResponse{
Message: "Verification code sent successfully",
Countdown: result.Countdown,
})
}
func (h *AuthHandler) upsertPendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
client := h.entClient()
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
existing, err := client.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)).
Only(c.Request.Context())
if err != nil && !dbent.IsNotFound(err) {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err)
}
if existing != nil && !req.hasDecision() {
return existing, nil
}
if existing == nil && !req.hasDecision() {
return nil, nil
}
input := service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
}
if existing != nil {
input.AdoptDisplayName = existing.AdoptDisplayName
input.AdoptAvatar = existing.AdoptAvatar
input.IdentityID = existing.IdentityID
}
if req.AdoptDisplayName != nil {
input.AdoptDisplayName = *req.AdoptDisplayName
}
if req.AdoptAvatar != nil {
input.AdoptAvatar = *req.AdoptAvatar
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input)
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func (h *AuthHandler) ensurePendingOAuthAdoptionDecision(
c *gin.Context,
sessionID int64,
req oauthAdoptionDecisionRequest,
) (*dbent.IdentityAdoptionDecision, error) {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req)
if err != nil {
return nil, err
}
if decision != nil {
return decision, nil
}
svc, err := h.pendingIdentityService()
if err != nil {
return nil, err
}
decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: sessionID,
})
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err)
}
return decision, nil
}
func updatePendingOAuthSessionProgress(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
intent string,
resolvedEmail string,
targetUserID *int64,
completionResponse map[string]any,
) (*dbent.PendingAuthSession, error) {
if client == nil || session == nil {
return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
localFlowState := clonePendingMap(session.LocalFlowState)
localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse)
update := client.PendingAuthSession.UpdateOneID(session.ID).
SetIntent(strings.TrimSpace(intent)).
SetResolvedEmail(strings.TrimSpace(resolvedEmail)).
SetLocalFlowState(localFlowState)
if targetUserID != nil && *targetUserID > 0 {
update = update.SetTargetUserID(*targetUserID)
} else {
update = update.ClearTargetUserID()
}
return update.Save(ctx)
}
func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) {
if session == nil {
return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid")
}
if session.TargetUserID != nil && *session.TargetUserID > 0 {
return *session.TargetUserID, nil
}
email := strings.TrimSpace(session.ResolvedEmail)
if email == "" {
return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing")
}
userEntity, err := findUserByNormalizedEmail(ctx, client, email)
if err != nil {
if errors.Is(err, service.ErrUserNotFound) {
return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found")
}
return 0, err
}
return userEntity.ID, nil
}
func userNormalizedEmailPredicate(email string) predicate.User {
normalized := strings.ToLower(strings.TrimSpace(email))
if normalized == "" {
return dbuser.EmailEQ(email)
}
return predicate.User(func(s *entsql.Selector) {
s.Where(entsql.P(func(b *entsql.Builder) {
b.WriteString("LOWER(TRIM(").
Ident(s.C(dbuser.FieldEmail)).
WriteString(")) = ").
Arg(normalized)
}))
})
}
func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) {
if client == nil {
return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
matches, err := client.User.Query().
Where(userNormalizedEmailPredicate(email)).
Order(dbent.Asc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, err
}
if len(matches) == 0 {
return nil, service.ErrUserNotFound
}
if len(matches) > 1 {
return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
}
return matches[0], nil
}
func ensurePendingOAuthRegistrationIdentityAvailable(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) error {
if client == nil || session == nil {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil
}
return err
}
if identity == nil || identity.UserID <= 0 {
return nil
}
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
if err != nil {
return err
}
if activeOwner != nil {
return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return nil
}
func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string {
if session == nil {
return nil
}
switch strings.TrimSpace(session.ProviderType) {
case "oidc":
issuer := strings.TrimSpace(session.ProviderKey)
if issuer == "" {
issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
}
if issuer == "" {
return nil
}
return &issuer
default:
issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer")
if issuer == "" {
return nil
}
return &issuer
}
}
func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") {
return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID)
}
client := tx.Client()
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)),
authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)),
authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)),
).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if identity != nil {
if identity.UserID != userID {
activeOwner, err := findActiveUserByID(ctx, client, identity.UserID)
if err != nil {
return nil, err
}
if activeOwner != nil {
return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
return client.AuthIdentity.UpdateOneID(identity.ID).
SetUserID(userID).
Save(ctx)
}
return identity, nil
}
create := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(strings.TrimSpace(session.ProviderType)).
SetProviderKey(strings.TrimSpace(session.ProviderKey)).
SetProviderSubject(strings.TrimSpace(session.ProviderSubject)).
SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims))
if issuer := oauthIdentityIssuer(session); issuer != nil {
create = create.SetIssuer(strings.TrimSpace(*issuer))
}
return create.Save(ctx)
}
func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) {
client := tx.Client()
providerType := strings.TrimSpace(session.ProviderType)
providerKey := strings.TrimSpace(session.ProviderKey)
providerSubject := strings.TrimSpace(session.ProviderSubject)
providerKeys := wechatCompatibleProviderKeys(providerKey)
channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel"))
channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id"))
channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject"))
metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims)
identityRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyIn(providerKeys...),
authidentity.ProviderSubjectEQ(providerSubject),
).
All(ctx)
if err != nil {
return nil, err
}
identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(ctx, client, identityRecords, userID, providerKey)
if err != nil {
return nil, err
}
var legacyOpenIDIdentity *dbent.AuthIdentity
if channelSubject != "" && channelSubject != providerSubject {
legacyOpenIDRecords, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderKeyIn(providerKeys...),
authidentity.ProviderSubjectEQ(channelSubject),
).
All(ctx)
if err != nil {
return nil, err
}
legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(ctx, client, legacyOpenIDRecords, userID, providerKey)
if err != nil {
return nil, err
}
}
switch {
case identity != nil:
update := client.AuthIdentity.UpdateOneID(identity.ID).
SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata))
if identity.UserID != userID {
update = update.SetUserID(userID)
}
if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey {
update = update.SetProviderKey(providerKey)
}
if issuer := oauthIdentityIssuer(session); issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*issuer))
}
identity, err = update.Save(ctx)
if err != nil {
return nil, err
}
case legacyOpenIDIdentity != nil:
update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata))
if issuer := oauthIdentityIssuer(session); issuer != nil {
update = update.SetIssuer(strings.TrimSpace(*issuer))
}
identity, err = update.Save(ctx)
if err != nil {
return nil, err
}
default:
create := client.AuthIdentity.Create().
SetUserID(userID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetProviderSubject(providerSubject).
SetMetadata(metadata)
if issuer := oauthIdentityIssuer(session); issuer != nil {
create = create.SetIssuer(strings.TrimSpace(*issuer))
}
identity, err = create.Save(ctx)
if err != nil {
return nil, err
}
}
if channel == "" || channelAppID == "" || channelSubject == "" {
return identity, nil
}
channelRecords, err := client.AuthIdentityChannel.Query().
Where(
authidentitychannel.ProviderTypeEQ(providerType),
authidentitychannel.ProviderKeyIn(providerKeys...),
authidentitychannel.ChannelEQ(channel),
authidentitychannel.ChannelAppIDEQ(channelAppID),
authidentitychannel.ChannelSubjectEQ(channelSubject),
).
WithIdentity().
All(ctx)
if err != nil {
return nil, err
}
channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(ctx, client, channelRecords, userID, providerKey)
if err != nil {
return nil, err
}
channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata)
if channelRecord == nil {
if _, err := client.AuthIdentityChannel.Create().
SetIdentityID(identity.ID).
SetProviderType(providerType).
SetProviderKey(providerKey).
SetChannel(channel).
SetChannelAppID(channelAppID).
SetChannelSubject(channelSubject).
SetMetadata(channelMetadata).
Save(ctx); err != nil {
return nil, err
}
return identity, nil
}
updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID).
SetIdentityID(identity.ID).
SetMetadata(channelMetadata)
if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey {
updateChannel = updateChannel.SetProviderKey(providerKey)
}
_, err = updateChannel.Save(ctx)
if err != nil {
return nil, err
}
return identity, nil
}
func chooseWeChatIdentityForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) {
var preferred *dbent.AuthIdentity
var fallback *dbent.AuthIdentity
hasCanonicalKey := false
for _, record := range records {
if record == nil {
continue
}
if record.UserID != userID {
activeOwner, err := findActiveUserByID(ctx, client, record.UserID)
if err != nil {
return nil, false, err
}
if activeOwner != nil {
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user")
}
}
if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
hasCanonicalKey = true
if preferred == nil {
preferred = record
}
continue
}
if fallback == nil {
fallback = record
}
}
if preferred != nil {
return preferred, hasCanonicalKey, nil
}
return fallback, hasCanonicalKey, nil
}
func chooseWeChatChannelForUser(ctx context.Context, client *dbent.Client, records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) {
var preferred *dbent.AuthIdentityChannel
var fallback *dbent.AuthIdentityChannel
hasCanonicalKey := false
for _, record := range records {
if record == nil {
continue
}
if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID {
activeOwner, err := findActiveUserByID(ctx, client, record.Edges.Identity.UserID)
if err != nil {
return nil, false, err
}
if activeOwner != nil {
return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user")
}
}
if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) {
hasCanonicalKey = true
if preferred == nil {
preferred = record
}
continue
}
if fallback == nil {
fallback = record
}
}
if preferred != nil {
return preferred, hasCanonicalKey, nil
}
return fallback, hasCanonicalKey, nil
}
func findActiveUserByID(ctx context.Context, client *dbent.Client, userID int64) (*dbent.User, error) {
if client == nil || userID <= 0 {
return nil, nil
}
userEntity, err := client.User.Get(ctx, userID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err)
}
if !strings.EqualFold(strings.TrimSpace(userEntity.Status), service.StatusActive) {
return nil, service.ErrUserNotActive
}
return userEntity, nil
}
func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any {
if channel == nil {
return map[string]any{}
}
return cloneOAuthMetadata(channel.Metadata)
}
func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool {
if session == nil || decision == nil {
return false
}
switch strings.ToLower(strings.TrimSpace(session.Intent)) {
case "bind_current_user", "login", "adopt_existing_user_by_email":
return true
default:
return decision.AdoptDisplayName || decision.AdoptAvatar
}
}
func shouldSkipAvatarAdoption(err error) bool {
return errors.Is(err, service.ErrAvatarInvalid) ||
errors.Is(err, service.ErrAvatarTooLarge) ||
errors.Is(err, service.ErrAvatarNotImage)
}
func applyPendingOAuthBinding(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
forceBind bool,
applyFirstBindDefaults bool,
) error {
if client == nil || session == nil {
return nil
}
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
if tx := dbent.TxFromContext(ctx); tx != nil {
return applyPendingOAuthBindingTx(ctx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults)
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthBindingTx(txCtx, tx, authService, userService, session, decision, overrideUserID, forceBind, applyFirstBindDefaults); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthBindingTx(
ctx context.Context,
tx *dbent.Tx,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
forceBind bool,
applyFirstBindDefaults bool,
) error {
if tx == nil || session == nil {
return nil
}
if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) {
return nil
}
targetUserID := int64(0)
if overrideUserID != nil && *overrideUserID > 0 {
targetUserID = *overrideUserID
} else {
resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, tx.Client(), session)
if err != nil {
return err
}
targetUserID = resolvedUserID
}
adoptedDisplayName := ""
if decision != nil && decision.AdoptDisplayName {
adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name"))
}
adoptedAvatarURL := ""
if decision != nil && decision.AdoptAvatar {
adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url")
}
shouldAdoptAvatar := false
if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" {
if err := service.ValidateUserAvatar(adoptedAvatarURL); err == nil {
shouldAdoptAvatar = true
} else if !shouldSkipAvatarAdoption(err) {
return err
}
}
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
if err := tx.Client().User.UpdateOneID(targetUserID).
SetUsername(adoptedDisplayName).
Exec(ctx); err != nil {
return err
}
}
identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID)
if err != nil {
return err
}
metadata := cloneOAuthMetadata(identity.Metadata)
for key, value := range session.UpstreamIdentityClaims {
metadata[key] = value
}
if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" {
metadata["display_name"] = adoptedDisplayName
}
if shouldAdoptAvatar {
metadata["avatar_url"] = adoptedAvatarURL
}
updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata)
if issuer := oauthIdentityIssuer(session); issuer != nil {
updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer))
}
if _, err := updateIdentity.Save(ctx); err != nil {
return err
}
if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) {
if _, err := tx.Client().IdentityAdoptionDecision.Update().
Where(
identityadoptiondecision.IdentityIDEQ(identity.ID),
identityadoptiondecision.IDNEQ(decision.ID),
).
ClearIdentityID().
Save(ctx); err != nil {
return err
}
if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID).
SetIdentityID(identity.ID).
Save(ctx); err != nil {
return err
}
}
if applyFirstBindDefaults && authService != nil {
if err := authService.ApplyProviderDefaultSettingsOnFirstBind(ctx, targetUserID, session.ProviderType); err != nil {
return err
}
}
if shouldAdoptAvatar && userService != nil {
if _, err := userService.SetAvatar(ctx, targetUserID, adoptedAvatarURL); err != nil {
return err
}
}
return nil
}
func consumePendingOAuthBrowserSessionTx(
ctx context.Context,
tx *dbent.Tx,
session *dbent.PendingAuthSession,
) error {
if tx == nil || session == nil {
return service.ErrPendingAuthSessionNotFound
}
storedSession, err := tx.Client().PendingAuthSession.Get(ctx, session.ID)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPendingAuthSessionNotFound
}
return err
}
now := time.Now().UTC()
if storedSession.ConsumedAt != nil {
return service.ErrPendingAuthSessionConsumed
}
if !storedSession.ExpiresAt.IsZero() && now.After(storedSession.ExpiresAt) {
return service.ErrPendingAuthSessionExpired
}
if strings.TrimSpace(storedSession.BrowserSessionKey) != "" &&
strings.TrimSpace(storedSession.BrowserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return service.ErrPendingAuthBrowserMismatch
}
if _, err := tx.Client().PendingAuthSession.UpdateOneID(storedSession.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx); err != nil {
return err
}
return nil
}
func applyPendingOAuthAdoptionAndConsumeSession(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
userID int64,
) error {
if client == nil {
return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
}
if session == nil || userID <= 0 {
return infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")
}
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(ctx, tx)
if err := applyPendingOAuthAdoption(txCtx, client, authService, userService, session, decision, &userID); err != nil {
return err
}
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
return err
}
return tx.Commit()
}
func applyPendingOAuthAdoption(
ctx context.Context,
client *dbent.Client,
authService *service.AuthService,
userService *service.UserService,
session *dbent.PendingAuthSession,
decision *dbent.IdentityAdoptionDecision,
overrideUserID *int64,
) error {
return applyPendingOAuthBinding(
ctx,
client,
authService,
userService,
session,
decision,
overrideUserID,
false,
strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user"),
)
}
func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) {
if len(payload) == 0 || len(upstream) == 0 {
return
}
displayName := pendingSessionStringValue(upstream, "suggested_display_name")
avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url")
if displayName != "" {
if _, exists := payload["suggested_display_name"]; !exists {
payload["suggested_display_name"] = displayName
}
}
if avatarURL != "" {
if _, exists := payload["suggested_avatar_url"]; !exists {
payload["suggested_avatar_url"] = avatarURL
}
}
if displayName != "" || avatarURL != "" {
payload["adoption_required"] = true
}
}
func pendingOAuthIdentityExistsForUser(
ctx context.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
userID int64,
) (bool, error) {
if client == nil || session == nil || userID <= 0 {
return false, nil
}
providerType := strings.TrimSpace(session.ProviderType)
providerKey := strings.TrimSpace(session.ProviderKey)
providerSubject := strings.TrimSpace(session.ProviderSubject)
if providerType == "" || providerSubject == "" {
return false, nil
}
query := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ(providerType),
authidentity.ProviderSubjectEQ(providerSubject),
authidentity.UserIDEQ(userID),
)
if strings.EqualFold(providerType, "wechat") {
query = query.Where(authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(providerKey)...))
} else if providerKey != "" {
query = query.Where(authidentity.ProviderKeyEQ(providerKey))
}
count, err := query.Count(ctx)
if err != nil {
return false, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err)
}
return count > 0, nil
}
func (h *AuthHandler) shouldSkipPendingOAuthAdoptionPrompt(
ctx context.Context,
session *dbent.PendingAuthSession,
payload map[string]any,
) (bool, error) {
if session == nil || len(payload) == 0 {
return false, nil
}
if !pendingOAuthCompletionCanIssueTokenPair(session, payload) {
return false, nil
}
if pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name") == "" &&
pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") == "" {
return false, nil
}
return pendingOAuthIdentityExistsForUser(ctx, h.entClient(), session, *session.TargetUserID)
}
func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) {
secureCookie := isRequestHTTPS(c)
clearCookies := func() {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
clearCookies()
return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch
}
svc, err := h.pendingIdentityService()
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearCookies()
return nil, nil, clearCookies, err
}
return svc, session, clearCookies, nil
}
func (h *AuthHandler) consumePendingOAuthSessionOnLogout(c *gin.Context) {
if c == nil || c.Request == nil {
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
return
}
svc, err := h.pendingIdentityService()
if err != nil {
return
}
_, _ = svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
}
func clearOAuthLogoutCookies(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
clearOAuthBindAccessTokenCookie(c, secureCookie)
clearCookie(c, linuxDoOAuthStateCookieName, secureCookie)
clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie)
clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie)
clearCookie(c, linuxDoOAuthIntentCookieName, secureCookie)
clearCookie(c, linuxDoOAuthBindUserCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthStateCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthVerifierCookie, secureCookie)
oidcClearCookie(c, oidcOAuthRedirectCookie, secureCookie)
oidcClearCookie(c, oidcOAuthNonceCookie, secureCookie)
oidcClearCookie(c, oidcOAuthIntentCookieName, secureCookie)
oidcClearCookie(c, oidcOAuthBindUserCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie)
wechatClearCookie(c, wechatOAuthBindUserCookieName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthStateName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthRedirect, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthContextName, secureCookie)
wechatPaymentClearCookie(c, wechatPaymentOAuthScope, secureCookie)
}
func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H {
completionResponse := normalizePendingOAuthCompletionResponse(mergePendingCompletionResponse(session, nil))
payload := gin.H{
"auth_result": "pending_session",
"provider": strings.TrimSpace(session.ProviderType),
"intent": strings.TrimSpace(session.Intent),
}
for key, value := range completionResponse {
payload[key] = value
}
if email := strings.TrimSpace(session.ResolvedEmail); email != "" {
payload["email"] = email
}
return payload
}
func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]any {
normalized := clonePendingMap(payload)
for _, key := range []string{"access_token", "refresh_token", "expires_in", "token_type"} {
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
normalized["step"] = oauthPendingChoiceStep
}
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
normalized["adoption_required"] = true
}
if _, exists := normalized["adoption_required"]; !exists {
if _, hasChoiceFields := normalized["email_binding_required"]; hasChoiceFields {
normalized["adoption_required"] = true
}
}
return normalized
}
func pendingOAuthChoiceCompletionResponse(session *dbent.PendingAuthSession, email string) map[string]any {
response := mergePendingCompletionResponse(session, map[string]any{
"step": oauthPendingChoiceStep,
"adoption_required": true,
"force_email_on_signup": true,
"email_binding_required": true,
"existing_account_bindable": true,
})
if email = strings.TrimSpace(email); email != "" {
response["email"] = email
response["resolved_email"] = email
}
return response
}
func (h *AuthHandler) transitionPendingOAuthAccountToChoiceState(
c *gin.Context,
client *dbent.Client,
session *dbent.PendingAuthSession,
targetUser *dbent.User,
email string,
) (*dbent.PendingAuthSession, error) {
completionResponse := pendingOAuthChoiceCompletionResponse(session, email)
var targetUserID *int64
if targetUser != nil && targetUser.ID > 0 {
targetUserID = &targetUser.ID
}
session, err := updatePendingOAuthSessionProgress(
c.Request.Context(),
client,
session,
strings.TrimSpace(session.Intent),
email,
targetUserID,
completionResponse,
)
if err != nil {
return nil, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)
}
return session, nil
}
func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) {
c.JSON(http.StatusOK, gin.H{
"access_token": tokenPair.AccessToken,
"refresh_token": tokenPair.RefreshToken,
"expires_in": tokenPair.ExpiresIn,
"token_type": "Bearer",
})
}
func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
var req bindPendingOAuthLoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password)
if err != nil {
response.ErrorFrom(c, err)
return
}
if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID {
response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user"))
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), user); err != nil {
response.ErrorFrom(c, err)
return
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
response.ErrorFrom(c, err)
return
}
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
tempToken, err := h.totpService.CreatePendingOAuthBindLoginSession(
c.Request.Context(),
user.ID,
user.Email,
session.SessionToken,
session.BrowserSessionKey,
)
if err != nil {
response.InternalError(c, "Failed to create 2FA session")
return
}
response.Success(c, TotpLoginResponse{
Requires2FA: true,
TempToken: tempToken,
UserEmailMasked: service.MaskEmail(user.Email),
})
return
}
if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil {
respondPendingOAuthBindingApplyError(c, err)
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
return
}
if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
func respondPendingOAuthBindingApplyError(c *gin.Context, err error) {
if code := infraerrors.Code(err); code >= http.StatusBadRequest && code < http.StatusInternalServerError {
response.ErrorFrom(c, err)
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
}
func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) {
var req createPendingOAuthAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
_, session, clearCookies, err := readPendingOAuthBrowserSession(c, h)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
response.ErrorFrom(c, err)
return
}
if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) {
response.BadRequest(c, "Pending oauth session provider mismatch")
return
}
client := h.entClient()
if client == nil {
response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
return
}
email := strings.TrimSpace(strings.ToLower(req.Email))
existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email)
if err != nil {
switch {
case errors.Is(err, service.ErrUserNotFound):
existingUser = nil
case infraerrors.Code(err) >= http.StatusBadRequest && infraerrors.Code(err) < http.StatusInternalServerError:
response.ErrorFrom(c, err)
return
default:
response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable"))
return
}
}
if existingUser != nil {
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.RegisterOAuthEmailAccount(
c.Request.Context(),
email,
req.Password,
strings.TrimSpace(req.VerifyCode),
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
)
if err != nil {
if errors.Is(err, service.ErrEmailExists) {
existingUser, lookupErr := findUserByNormalizedEmail(c.Request.Context(), client, email)
if lookupErr != nil {
response.ErrorFrom(c, lookupErr)
return
}
session, err = h.transitionPendingOAuthAccountToChoiceState(c, client, session, existingUser, email)
if err != nil {
response.ErrorFrom(c, err)
return
}
c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session))
return
}
response.ErrorFrom(c, err)
return
}
rollbackCreatedUser := func(originalErr error) bool {
if user == nil || user.ID <= 0 {
return false
}
if rollbackErr := h.authService.RollbackOAuthEmailAccountCreation(
c.Request.Context(),
user.ID,
strings.TrimSpace(req.InvitationCode),
); rollbackErr != nil {
response.ErrorFrom(c, infraerrors.InternalServer(
"PENDING_AUTH_ACCOUNT_ROLLBACK_FAILED",
"failed to rollback pending oauth account creation",
).WithCause(fmt.Errorf("original error: %w; rollback error: %v", originalErr, rollbackErr)))
return true
}
user = nil
return false
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision())
if err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, err)
return
}
tx, err := client.Tx(c.Request.Context())
if err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
defer func() { _ = tx.Rollback() }()
txCtx := dbent.NewTxContext(c.Request.Context(), tx)
if err := applyPendingOAuthBinding(txCtx, client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
respondPendingOAuthBindingApplyError(c, err)
return
}
if err := h.authService.FinalizeOAuthEmailAccount(
txCtx,
user,
strings.TrimSpace(req.InvitationCode),
strings.TrimSpace(session.ProviderType),
); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, err)
return
}
if err := consumePendingOAuthBrowserSessionTx(txCtx, tx, session); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
clearCookies()
response.ErrorFrom(c, err)
return
}
if pendingOAuthCreateAccountPreCommitHook != nil {
if err := pendingOAuthCreateAccountPreCommitHook(txCtx, session); err != nil {
_ = tx.Rollback()
if rollbackCreatedUser(err) {
return
}
respondPendingOAuthBindingApplyError(c, err)
return
}
}
if err := tx.Commit(); err != nil {
if rollbackCreatedUser(err) {
return
}
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err))
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload.
// POST /api/v1/auth/oauth/pending/exchange
func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
secureCookie := isRequestHTTPS(c)
clearCookies := func() {
clearOAuthPendingSessionCookie(c, secureCookie)
clearOAuthPendingBrowserCookie(c, secureCookie)
}
adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c)
if err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken, err := readOAuthPendingSessionCookie(c)
if err != nil || strings.TrimSpace(sessionToken) == "" {
clearCookies()
response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
return
}
browserSessionKey, err := readOAuthPendingBrowserCookie(c)
if err != nil || strings.TrimSpace(browserSessionKey) == "" {
clearCookies()
response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
return
}
svc, err := h.pendingIdentityService()
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
payload, ok := readCompletionResponse(session.LocalFlowState)
if !ok {
clearCookies()
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid"))
return
}
payload = normalizePendingOAuthCompletionResponse(payload)
if strings.TrimSpace(session.RedirectTo) != "" {
if _, exists := payload["redirect"]; !exists {
payload["redirect"] = session.RedirectTo
}
}
applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims)
canIssueTokenPair := pendingOAuthCompletionCanIssueTokenPair(session, payload)
var loginUser *service.User
if canIssueTokenPair {
loginUser, err = h.userService.GetByID(c.Request.Context(), *session.TargetUserID)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := ensureLoginUserActive(loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if err := h.ensureBackendModeAllowsUser(c.Request.Context(), loginUser); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
}
skipAdoptionPrompt, err := h.shouldSkipPendingOAuthAdoptionPrompt(c.Request.Context(), session, payload)
if err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if skipAdoptionPrompt {
delete(payload, "adoption_required")
}
if pendingSessionWantsInvitation(payload) {
if adoptionDecision.hasDecision() {
decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision)
if err != nil {
response.ErrorFrom(c, err)
return
}
_ = decision
}
response.Success(c, payload)
return
}
if !adoptionDecision.hasDecision() {
adoptionRequired, _ := payload["adoption_required"].(bool)
if adoptionRequired {
response.Success(c, payload)
return
}
}
decisionReq := adoptionDecision
if !decisionReq.hasDecision() {
adoptDisplayName := false
adoptAvatar := false
decisionReq = oauthAdoptionDecisionRequest{
AdoptDisplayName: &adoptDisplayName,
AdoptAvatar: &adoptAvatar,
}
}
decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, decisionReq)
if err != nil {
response.ErrorFrom(c, err)
return
}
if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil {
response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err))
return
}
if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil {
clearCookies()
response.ErrorFrom(c, err)
return
}
if canIssueTokenPair {
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), loginUser, "")
if err != nil {
clearCookies()
response.InternalError(c, "Failed to generate token pair")
return
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), loginUser.ID)
payload["access_token"] = tokenPair.AccessToken
payload["refresh_token"] = tokenPair.RefreshToken
payload["expires_in"] = tokenPair.ExpiresIn
payload["token_type"] = "Bearer"
}
clearCookies()
response.Success(c, payload)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment