From 6c89d8d35cfc3fabd38fdfc8463203ee42ed1c2d Mon Sep 17 00:00:00 2001 From: fjl5 Date: Mon, 13 Apr 2026 17:30:49 +0800 Subject: [PATCH 001/265] =?UTF-8?q?add=20prompt=5Fcache=5Fkey=20injection?= =?UTF-8?q?=20for=20messages=E2=86=92responses?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/openai_gateway_messages.go | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index a72b9bbf..2a0a72eb 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -121,6 +121,28 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( } } + // For API key accounts (including OpenAI-compatible upstream gateways), + // ensure promptCacheKey is also propagated via the request body so that + // upstreams using the Responses API can derive a stable session identifier + // from prompt_cache_key. This makes our Anthropic /v1/messages compatibility + // path behave more like a native Responses client. + if account.Type == AccountTypeAPIKey { + if trimmedKey := strings.TrimSpace(promptCacheKey); trimmedKey != "" { + var reqBody map[string]any + if err := json.Unmarshal(responsesBody, &reqBody); err != nil { + return nil, fmt.Errorf("unmarshal for prompt cache key injection: %w", err) + } + if existing, ok := reqBody["prompt_cache_key"].(string); !ok || strings.TrimSpace(existing) == "" { + reqBody["prompt_cache_key"] = trimmedKey + updated, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after prompt cache key injection: %w", err) + } + responsesBody = updated + } + } + } + // 5. Get access token token, _, err := s.GetAccessToken(ctx, account) if err != nil { -- GitLab From 10699eeb34e0e9b602a50d43a1be434d4751d5af Mon Sep 17 00:00:00 2001 From: erio Date: Thu, 16 Apr 2026 01:53:22 +0800 Subject: [PATCH 002/265] refactor: extract ReadUpstreamResponseBody to deduplicate upstream response read + too-large error handling Consolidates 9 call sites of resolveUpstreamResponseReadLimit + readUpstreamResponseBodyLimited + ErrUpstreamResponseBodyTooLarge error handling into a single ReadUpstreamResponseBody function with TooLargeWriter callback for API-format-specific error responses (Anthropic, OpenAI, countTokens). --- backend/internal/service/gateway_service.go | 74 +++++-------------- .../service/gemini_messages_compat_service.go | 12 +-- .../service/openai_gateway_service.go | 24 +----- .../service/upstream_response_limit.go | 43 +++++++++++ .../service/upstream_response_limit_test.go | 43 +++++++++++ 5 files changed, 107 insertions(+), 89 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index c65e828a..4b4fc0bf 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5120,19 +5120,8 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -5498,19 +5487,8 @@ func (s *GatewayService) handleBedrockNonStreamingResponse( c *gin.Context, account *Account, ) (*ClaudeUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -7175,19 +7153,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -8300,16 +8267,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -8323,15 +8289,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) if retryErr == nil { resp = retryResp - respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } } @@ -8426,16 +8389,15 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex return fmt.Errorf("upstream request failed: %w", err) } - maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) _ = resp.Body.Close() if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - return err + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") } - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 5a9490f3..7a24071b 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2424,18 +2424,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ef97daad..064191bd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -3010,18 +3010,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( resp *http.Response, c *gin.Context, ) (*OpenAIUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } @@ -3919,18 +3909,8 @@ func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - maxBytes := resolveUpstreamResponseReadLimit(s.cfg) - body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream response too large", - }, - }) - } return nil, err } diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index aecf69a3..a0444d52 100644 --- a/backend/internal/service/upstream_response_limit.go +++ b/backend/internal/service/upstream_response_limit.go @@ -4,8 +4,10 @@ import ( "errors" "fmt" "io" + "net/http" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" ) var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") @@ -36,3 +38,44 @@ func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, } return body, nil } + +// TooLargeWriter 在响应超限时向客户端写格式化的错误响应。 +type TooLargeWriter func(c *gin.Context) + +// ReadUpstreamResponseBody 读取上游非流式响应体。 +// 超限时自动记录 ops error 并调用 onTooLarge 向客户端写错误。 +func ReadUpstreamResponseBody(reader io.Reader, cfg *config.Config, c *gin.Context, onTooLarge TooLargeWriter) ([]byte, error) { + maxBytes := resolveUpstreamResponseReadLimit(cfg) + body, err := readUpstreamResponseBodyLimited(reader, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + if onTooLarge != nil { + onTooLarge(c) + } + } + return nil, err + } + return body, nil +} + +// anthropicTooLargeError 以 Anthropic Messages API 格式写入超限错误。 +func anthropicTooLargeError(c *gin.Context) { + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) +} + +// openAITooLargeError 以 OpenAI / Gemini 格式写入超限错误。 +func openAITooLargeError(c *gin.Context) { + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) +} diff --git a/backend/internal/service/upstream_response_limit_test.go b/backend/internal/service/upstream_response_limit_test.go index b9e5cc6d..09283189 100644 --- a/backend/internal/service/upstream_response_limit_test.go +++ b/backend/internal/service/upstream_response_limit_test.go @@ -4,8 +4,10 @@ import ( "bytes" "errors" "testing" + "testing/iotest" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -35,3 +37,44 @@ func TestReadUpstreamResponseBodyLimited(t *testing.T) { require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) }) } + +func TestReadUpstreamResponseBody(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("ok")), nil, nil, nil) + require.NoError(t, err) + require.Equal(t, []byte("ok"), body) + }) + + t.Run("exceeds limit calls onTooLarge", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + called := false + onTooLarge := func(_ *gin.Context) { called = true } + + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, onTooLarge) + require.Nil(t, body) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + require.True(t, called) + }) + + t.Run("nil onTooLarge does not panic", func(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UpstreamResponseReadMaxBytes = 3 + + body, err := ReadUpstreamResponseBody(bytes.NewReader([]byte("toolong")), cfg, nil, nil) + require.Nil(t, body) + require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + }) + + t.Run("io error does not call onTooLarge", func(t *testing.T) { + called := false + onTooLarge := func(_ *gin.Context) { called = true } + + body, err := ReadUpstreamResponseBody(iotest.ErrReader(errors.New("disk failure")), nil, nil, onTooLarge) + require.Nil(t, body) + require.Error(t, err) + require.False(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge)) + require.False(t, called) + }) +} -- GitLab From 3944b3d216bbcc10472fccdc4f6606fa2b85106e Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Thu, 16 Apr 2026 02:01:50 +0000 Subject: [PATCH 003/265] fix: preserve openai ws flags in scheduler cache --- .../internal/repository/scheduler_cache.go | 7 +++ .../repository/scheduler_cache_unit_test.go | 33 ++++++++++ ...enai_account_scheduler_ws_snapshot_test.go | 62 +++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 backend/internal/repository/scheduler_cache_unit_test.go create mode 100644 backend/internal/service/openai_account_scheduler_ws_snapshot_test.go diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index e9be8c7a..add0e501 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -426,6 +426,13 @@ func filterSchedulerExtra(extra map[string]any) map[string]any { "window_cost_sticky_reserve", "max_sessions", "session_idle_timeout_minutes", + "openai_oauth_responses_websockets_v2_enabled", + "openai_oauth_responses_websockets_v2_mode", + "openai_apikey_responses_websockets_v2_enabled", + "openai_apikey_responses_websockets_v2_mode", + "responses_websockets_v2_enabled", + "openai_ws_enabled", + "openai_ws_force_http", } filtered := make(map[string]any) for _, key := range keys { diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go new file mode 100644 index 00000000..bcfd0e7a --- /dev/null +++ b/backend/internal/repository/scheduler_cache_unit_test.go @@ -0,0 +1,33 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) { + account := service.Account{ + ID: 42, + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + "openai_ws_force_http": true, + "mixed_scheduling": true, + "unused_large_field": "drop-me", + }, + } + + got := buildSchedulerMetadataAccount(account) + + require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"]) + require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"]) + require.Equal(t, true, got.Extra["openai_ws_force_http"]) + require.Equal(t, true, got.Extra["mixed_scheduling"]) + require.Nil(t, got.Extra["unused_large_field"]) +} diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go new file mode 100644 index 00000000..c5de8203 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -0,0 +1,62 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapshotFlags(t *testing.T) { + ctx := context.Background() + groupID := int64(10105) + account := &Account{ + ID: 35001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 10, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + + snapshotCache := &openAISnapshotCacheStub{ + snapshotAccounts: []*Account{account}, + accountsByID: map[int64]*Account{account.ID: account}, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &stubGatewayCache{}, + cfg: cfg, + schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_ws_passthrough", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) +} -- GitLab From 836092a6665f3f8f50ef8f4b055b89d73c748310 Mon Sep 17 00:00:00 2001 From: KnowSky404 Date: Thu, 16 Apr 2026 02:13:04 +0000 Subject: [PATCH 004/265] fix: restore ctx pool ws mode option in account ui --- frontend/src/components/account/BulkEditAccountModal.vue | 2 ++ frontend/src/components/account/CreateAccountModal.vue | 5 ++--- frontend/src/components/account/EditAccountModal.vue | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 5461015b..13c30cf9 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -921,6 +921,7 @@ import { getPresetMappingsByPlatform } from '@/composables/useModelWhitelist' import { + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, @@ -1069,6 +1070,7 @@ const isOpenAIModelRestrictionDisabled = computed( const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openAIWSModeConcurrencyHintKey = computed(() => diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 432fa080..2130c9ab 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2932,7 +2932,7 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - // OPENAI_WS_MODE_CTX_POOL, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, @@ -3180,8 +3180,7 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 - // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index e5065d7a..1da32e2c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1858,7 +1858,7 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - // OPENAI_WS_MODE_CTX_POOL, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, @@ -2020,8 +2020,7 @@ const editWeeklyResetHour = ref(null) const editResetTimezone = ref(null) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - // TODO: ctx_pool 选项暂时隐藏,待测试完成后恢复 - // { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ -- GitLab From a55ead5ea860c2517c625fbbe6cb2f1823382954 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Apr 2026 16:42:40 +0800 Subject: [PATCH 005/265] chore: remove empty dir Antigravity-Manager --- Antigravity-Manager | 1 - 1 file changed, 1 deletion(-) delete mode 160000 Antigravity-Manager diff --git a/Antigravity-Manager b/Antigravity-Manager deleted file mode 160000 index a9d96bd5..00000000 --- a/Antigravity-Manager +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a9d96bd54978c22d3033830debfe77aeeeee2500 -- GitLab From 7ea8e7e6675c3e5b4c5c73578136420c76a4ee64 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 16 Apr 2026 17:19:32 +0800 Subject: [PATCH 006/265] chore: update sponsors --- README.md | 5 +++++ README_CN.md | 5 +++++ README_JA.md | 5 +++++ assets/partners/logos/bmoplus.jpg | Bin 0 -> 7996 bytes 4 files changed, 15 insertions(+) create mode 100644 assets/partners/logos/bmoplus.jpg diff --git a/README.md b/README.md index c2715eae..74ab9af2 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link, you'll receive an extra 10% bonus credit on your first top-up! + +bmoplus +Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF) + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index 0ace1f77..c701372c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -90,6 +90,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接注册,首次充值可额外获得 10% 赠送额度! + +bmoplus +感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index d74ca9ce..0d4db616 100644 --- a/README_JA.md +++ b/README_JA.md @@ -90,6 +90,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンクから登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント! + +bmoplus +本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます! + + ## エコシステム diff --git a/assets/partners/logos/bmoplus.jpg b/assets/partners/logos/bmoplus.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a9b4d8b7b2d626d860f327a49fe0bf3ade711fb GIT binary patch literal 7996 zcmb7o1yr0tvhLs>7-Z1k&fpTg$y>i$ct^Ish300tJyepv#b!@&SxvEZptvqvdea zGOycfnABnoDwcYd>=!B}IF&h8&UFI3x?@s4$Dit|(QYrc7Ro?`skB=cf3<~i7S`ii z^iM`JT^=QdYSdk(RAqZd?Nx5kwh8!1?0B4(MY?y7dQB^N8|+&D$WdeC+R9_V zW8>tSDd|@^{Fy!`Rbuf(SgoD(_Z^1~Dt$VA>L2Gj-ZkZ#GYu>j#WuM|P2SuG*>5Cx zY>$^bC_k1vK<2nyYpqy0xi0nuoW@YO7*-33FTIoZEtkqI7j+Z(7?x0=X1dOqTTYf< z3EVZbvpVANE-uG%2)KICOPeCl|^I{&G^K}hwaW8mrCn%x+0al5t$XlHZ^qVoe^ zW^sc3O-X_mN6WSz2d zRw^`i{u>4V3rPeTfGk)5SXdZ1I5=3uf5agSEC3D&kBx)H!HJ7Ufxt%2E+(#GOi3l` zpEHVm#<2o>}I*`-rvmeQ64QUhgw^Re6AmZz|Vky_iI2{KVN*cKC3`*I!wrhs@!mD@;gN zX}Okx5Ytodm_&R0q=^-+xUaSqb}i9xzx3FY?dqmym2Ft@cWxAz7=Oz&u2UEVuxGN=)-7KakC}uIUSfC z8Knp^3v1Q{fpsEu`o$9pQ{AJ+LD}>JJa5pHlNnu*);LveQPJ3lrAtn1ueeUUib^dY zwmiC~juKqI9mIa9Zfi30LYR{Lw%@q8e_l(uEx9mQl?||gFKj-!cBx4&9Aa$I`kV0xiMX(Lb^V6ao!FMP=}XSK4Var?ty zK#ij6mzbWsdOQutMY}|S=xEEy6*^-9if=;J{E5=U^7I6rp?u51W6&(wd13p>{VW?E zH7fiQ-NEz(wnvXz+BF{D7q)|<)543_+^>^8MlkPFn@^TL_RfBG>ten5u|?)KW>KBL zU9+L};8i}U9{d`8PCHPl)xG*u3qfDTQuLtcf+_jSmQkfwL#2F?f4>Uy)8P;VJK>$Opg!r7BvSqaL4!I`;HKA zPn-K-{*2K2!@d93cLQ1u3sq&f4MXmrVpsY@%w$&VQ2tG*^hgJbtLm-BtHu4tW$REV{)$tUaMoZr(eq zZ_ib0Bqpq0U}tp5Sbx+Fh)TcY?sHnp);ElBMxsEkKglA0UdSr_{bj`Y-WG9gJ$>k1 zbnn(;^CO?|N@afGDsRt;NH0lL7k)W7r~GZ$3mUSx{afBaI_ze9AOw+>~=r(R6W2~T-H~DTT3@T92p)7 zDAh%;)4Xwn#)g1DrHR^DUA0@BI#c` zmPGsm7e#)nu2X${%0aj1pGbAT;7nro=QFBxgrhfz3oozp_Lj3-QBrQ8Kf^wS<6Xmv zKS6%J{Sxco{8YtwlktlVj@{RAbX@hkBwk!k8CWBPNgLe(S<$)lV~fLS8b>)=z^Z zUf_0m9;qtXecvapt{qUCSUp_g@GzcSmt9Q**0+jQCw5QCG4D*K&;;}r2M4#eTci;U zgG=g8EHt&7Y6&f4Tx+L0F0$<-D@=7YCB!W^OI_OIKEp0)P3U-Prq2-V(~W@dEFxvq zw(pB%`#Ng($j6Y5eyil$Wdf}K@n0E zmAUj@IcG`&@!wbx91N)9t2T##2isnPq!_)jS6;kkk}iUg)M4vqGStGgY9F{@(jHL+_(g2V|RZa;H9}L~ggsE7;ZcKjC+Vycw8~do#p4RJDeD z64BLc7%lrimCv|d8Y>b=*tmWwyE>M+>ygeZJ$T5Qx#$}LT{9=NP(q<80un63KOzJG z1H__$Wf#LH=TJ3)!{K!HkI!K8u@~iqswL}zwR!h(bubte;(Qzgok5Btp zrgfEVIo;7M)zRGp&s$5=-r}iKacSAXo$Flm`1p^$$}w(%R+z6a@0GqCtu?)VKj%a9g4*f8ZWDMP9;GoC0CbxWICCwwZg z;P8TEj$&H2VrU-n+GhVIEtX?8VrUFPl3qLH)>PfAGxvVvsZuJe!#?T#v;&*|UqHb~d>T!PW&;Oy)FL<2`%@fZ%jF+bCeEfHybKUR zLY{ib=(1sMuUyFmyJnPrl!lvJg+e6rE~5(_chf$%=AIUXCP7JihkKh;(01JU!|PvT zJ5BX%O0n!4-mz$7m^I!On%g?lW%^N)7u16(ilkRSBr)W+!H6FIT-EsWPviMhlktBv z)b#hSSXg2UeCuV z@{t%J@prGxr*fL($m_I#4O`Ru&Rp}yDR1)lSL&!^%U8?3M?Pq-lJYvD@nnOjtk93m zs=A+I@Q<82Ozf5Bi8<{pLzlP-Op(pHJ(OpTNIaV;b!%2f=7u&t#kKJMq8f;#_Qbp) zMY<5g-1a0JJR>^+f%!Nset)j^G@>QzS^TuElVYu$y?`N@PC$iy$*bMe-8T57%GJwDBnmAFAr_U*=ES04k>V`xF_}8(!_kmO*)=t4gTzG5vKu1UPuue;t-!fE4Ul zVyY(2{yA0TZ1HEtT~1TyqALG{Xd)EP%r2j(!7l(vF6h9qM_2U3HBh51(-)nPxig}~ zCqZ&{Zi*CtLtzt$lv~Oo7co3)6jYPpsp2|;Dz~2ivzOMU_Lll#`Ym-dEo~$Iv%ZA{ zIO1^7dQ)DD5mFy8f}Uf3q^83tqQDh0JL`ZE*Z*1^%mB zFeu$v7Zz7U-FL`x1kxtxxmE#JMAeg_EoX+<;#S{%SJK9KANZAe+&GYC z@G$0Ew()7G3_d?f#74Ib9%Jye0mgAicD&TCv47 zb&i4%2SmNMz2(g5Cxh`Cte9#di*BjSXkZ0}& z;}P?ti4JB*_UA|#20||mq$U~I67#bf)>mTY)5v)!4j(I|A@=-Sk|V>wluSPKx5>9y z;R%^O!29pwV~pGwuaMLShWqD0q>tt`aMdt7)B{WU$jmttwACt5; z^|8Jo<(r?;>YCi(P+@Bz#aL^D8jW@L1bRB(x$hnJxCVJW``qVRZj|sJ#cPrTm+N_v673iFll3o(;Al+cM`ks*9+IJDjjErXGBpJI(Qi(N^7D= zNwa1-rQK;;2+?^(uLfw5?Cx{h)30a)QPi72ig6`a!nimD4G9EAnr1&-u829%&rr;a zs4wKrRamG3f=N~*5Na_~ElfU3bW$v5e?u4t?%drpq3YTUq=I*FK3V&=Yd{*be!pBU0ep_`6qk)5ipyb~io>=NqcC8H0wUuOR<#P1&XSlNaR+XqJZ^gTU7W_= z?{Q%WKug)4luPW_6RO5(Bc;x0rp-C~$FDWIHF6k9IWY2h);59Hw6*x$jJv(z%(cHk z$<^35nmFnS1oJ3B*qpL@20P6SjBp!SE#2(n*qece4YzP_?WyRaiTJbv_;-k(_fm z3-%dANOBr8i^cF;fxv|jIUeflbyPW3be|~}e`U2iN^bNTM3=gN)SN4Ze?0Vynm>>f zkwkUA!;>qRDKw;(_gy6mf%p2Q-fj&@A@%`sXjcz`C3e{9pQj=>H&a z5nKZIKWMng%2yWt`pCAW%(cRJcR9{=IVRXM$~E)%l`CG-7G?TsV<+ba>z~OQi{nB+ z;}M=)jply-F-dfECd|rOD&)y5JZ(F*)huA1o0*8(#})1mmt)Ah|7&xeRWD(LoLQnF zmBoANmIT?N7-z9RT-x%6Rzkka|6n$u9GjRd9?5^;{+aqeb!f3*grYMTSa<{k=+)Xk z;Mt#IA*Nc@1+5lSXXGmJIU87_#^--)#UCV(_4`0U2S?6oaUoJJAji=K2P8c5z6GmW zA@b6j-Is3l)tkw7Lmx-`wKv_u)@Hh^093U4FTuI zd76B2)bn!`R!AO$Ww(8^M6NM)f~I6|PyDj&NR}@qAj#6t{FnU60N!y&H(7ea{h^eFY+j?$$((xN-0L z$40(V8>Pgk0mW z&J4;a+c!mV^(2+}5Q72!@elaLzuq|x4w$u(^@UImO(^uo-E(%PCDuhIlurG|4He)Z z(x@~J9X7@MRYf!@gLgN%*};!vl02LZBb$->91Z=d)WAPisL(?iEX+SoXn?;EK*fX{ zdN4C~&N=)0J^w&%tlNz;ufm=f>6~m-Y&cz!JjnZPyq?Z)9ixt%PogzwN;?$mqog#CjIcJLbSRN033ikgmt8_o(&_imC3tT`diQbWto^ zC$8DWB%JH2gHdj*o=4dP%_L>7u~*x+n%3>3vT5erTc}3)ec7~;$=sag3IpzG>G5M% zJB21*Tii9Qhogs`kUC7cr>Ua|3L@Elh^5YwI*vPtvu-OAGVJAxBcX{IG4My)t1_$* zbcP$<<-@$`*d0FMG)0#mX?=Vp(uMwXGi~%1ElkV=>}%!=-4!qvbWfnpjtJe9|8jPi zKfCf57h%QcRCS$Ak+T_ZK!K6dzmO5uvyg93N@Ui>&M=%5Yq7Umu~MlS;)qCyfNfB| z(T69RH~J%FC;IW09l<3uZaceu`Fc%SX42^-OI}jl3H?CgUqcZi&$!1i>ua35s5mDY zqLeyGiJ_2KYzr1lT7KKuvs4+roEDi~jU2t2^{$Mzw0Hsa(UCfA4ZVJwqRS7J8{17ajc<(;9NZG8hnOv_CEJrEvd8EV|`~!mO z$$iYc=jq_Z%2sll_x7<~S?}ai!DrRFTP_Kk{`nE>k-l!M6P2$Rmsb;h%e-&oGgxz= zQalFH-2Mz65m6kLKhLY_-$~dRp?Lvx7FBSBw^LFF%zgc$&h1&nMS`;O@%u-8F*8krKC1?d`WZ~W^oBn zc*X&PfoN|Ym07No6NZ>eE!UYwwGq3t6UUttmyq*Uv>ur*EmCE7%t$d)Of`2s%gTKA z?q$dInJJ9Hj2~Hv?`qgUrFn%c2sC%yM>eSxsneLbZIU7Iz$rpEL6#IE;xmFov0`gy^RAOP&!3CL2X7u#KYQj~^vh%5LGPy}O8txKSF+4lx&E}2pp6`*w zY>p(ArE)r#3}Xfa`VIDvT0bHatl=D-U!S-M@XViWd^*BI$j;5c zVIMKz8#T)GbsYxd+T47mttfwY^84n?IV30n#F&Ye1YXMS-`&``=3ki0^0qQX3S}&YSqi!G*3mm z3HeU&p3JJD`30cDz|FQZu!%lK|2=>98_&Ryu`U=LxlIAmMj79&TsjnN)|dt3H&?p& zyuosY>Dx=_FXQBcaw;?SqaS_f{ri&#)y$mO zq-$TVeb?2Gp1^&>?T2@hxpvv^Xf1Fz!x~-o7KRcJ!RjUti86(|M(8TTk{yYc4ETb` zbYf3%LxTbD4v~t55CJK{>utO3w72)t2pD>KAp^|Rx=C1q986;Y$8p8}MU})B+daS) z%(ibuQU=ziVScu|&8}+DmU}Bom%AqgQ=5c;aM2c8uY~f`pU9QD z5K-oBBl_cE*tK9yx;*yyYQ<@AWyPP_?1#7`{=l(jE6M2r%Fc@Z{BpH`I?n6WLyj0d zL+nUd|DDmQE}5i7KF8j|B;StB1(Ze7yZxi_qMIvd6UW6TdIN+13BGs;dA;Zyw z&LH7sGaP(MsOGZ~^3$kSJ}dteAzN|7!G{i+b$<>~#M!zjGvXQp#Zyg&@U1L@_YoN< z?zc$$LXRbVr!$@pj~;%L>(=l?*1D<{I8WQKY_Oe1AbdA_Zh{Y=6w_^;0@g;(d0?$q z7LW7fxs!WWPFb)pyAVfGf+n%k&V^TLJw;^FQvm5qprF!=;7kPU5*!o-Og`IYjX7{V z#8!ud%@$cTBD%dIwAN+)ehZ!H^=l++c+c+A6@?`Gy4m4R*u!QqgRSS&@RPyd`9PG{ zNDNZFXwvd2SWJV!u`_LNp#0hC-;88^{wo$kx$W%Vxa8_qJcK`OcKZTwkBT>e`^?h6O4+}q ze@QE+%SpeD(OcA6?~~*d*?i@YaQvyvH!JpX>e(CZ0`ZDG2-Uu~Q5b?JHl5!xB7(hy z2M-qDrAmO|i+{(Ecjyaf#3BRweS5c|&Az5>d?o6?ZBpxE*6}Jx%J`}}ci(npL-++S z@&En^SPY%2*N+~0MLBKT_oPLH!1dP0K}5yhwJF}2_ Date: Thu, 16 Apr 2026 19:09:40 +0800 Subject: [PATCH 007/265] fix: fix outbox watermark context expiry and add in-batch group rebuild dedup Fixes #1691 - pollOutbox() reused a 10s context for SetOutboxWatermark after event processing could take much longer, causing "outbox watermark write failed: context deadline exceeded". The watermark never advanced so the same 200 events were reprocessed every poll cycle, spiking CPU. Now uses an independent 5s context with up to 3 retries (200ms apart). - When multiple Codex accounts sharing the same 21-22 groups are all rate-limited in quick succession, each account_changed event triggered redundant bucket rebuild attempts for the same groups. Introduce batchSeenKey{groupID, platform} and thread a seen map through the handler chain; rebuildBucketsForPlatform skips (group, platform) pairs already rebuilt within the same poll batch (~80% fewer rebuild calls in the 5-accounts-same-groups scenario). Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../service/scheduler_snapshot_service.go | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index d1330abb..5ead45bc 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -20,6 +20,14 @@ var ( const outboxEventTimeout = 2 * time.Minute +// batchSeenKey tracks which (groupID, platform) bucket sets have already been +// rebuilt within a single pollOutbox call, to avoid redundant work when multiple +// account_changed events share the same groups. +type batchSeenKey struct { + groupID int64 + platform string +} + type SchedulerSnapshotService struct { cache SchedulerCache outboxRepo SchedulerOutboxRepository @@ -244,9 +252,10 @@ func (s *SchedulerSnapshotService) pollOutbox() { } watermarkForCheck := watermark + seen := make(map[batchSeenKey]struct{}) for _, event := range events { eventCtx, cancel := context.WithTimeout(context.Background(), outboxEventTimeout) - err := s.handleOutboxEvent(eventCtx, event) + err := s.handleOutboxEvent(eventCtx, event, seen) cancel() if err != nil { logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) @@ -255,8 +264,20 @@ func (s *SchedulerSnapshotService) pollOutbox() { } lastID := events[len(events)-1].ID - if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil { - logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err) + wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer wmCancel() + var wmErr error + for i := range 3 { + wmErr = s.cache.SetOutboxWatermark(wmCtx, lastID) + if wmErr == nil { + break + } + if i < 2 { + time.Sleep(200 * time.Millisecond) + } + } + if wmErr != nil { + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", wmErr) } else { watermarkForCheck = lastID } @@ -264,18 +285,18 @@ func (s *SchedulerSnapshotService) pollOutbox() { s.checkOutboxLag(ctx, events[0], watermarkForCheck) } -func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent) error { +func (s *SchedulerSnapshotService) handleOutboxEvent(ctx context.Context, event SchedulerOutboxEvent, seen map[batchSeenKey]struct{}) error { switch event.EventType { case SchedulerOutboxEventAccountLastUsed: return s.handleLastUsedEvent(ctx, event.Payload) case SchedulerOutboxEventAccountBulkChanged: - return s.handleBulkAccountEvent(ctx, event.Payload) + return s.handleBulkAccountEvent(ctx, event.Payload, seen) case SchedulerOutboxEventAccountGroupsChanged: - return s.handleAccountEvent(ctx, event.AccountID, event.Payload) + return s.handleAccountEvent(ctx, event.AccountID, event.Payload, seen) case SchedulerOutboxEventAccountChanged: - return s.handleAccountEvent(ctx, event.AccountID, event.Payload) + return s.handleAccountEvent(ctx, event.AccountID, event.Payload, seen) case SchedulerOutboxEventGroupChanged: - return s.handleGroupEvent(ctx, event.GroupID) + return s.handleGroupEvent(ctx, event.GroupID, seen) case SchedulerOutboxEventFullRebuild: return s.triggerFullRebuild("outbox") default: @@ -309,7 +330,7 @@ func (s *SchedulerSnapshotService) handleLastUsedEvent(ctx context.Context, payl return s.cache.UpdateLastUsed(ctx, updates) } -func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any) error { +func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, payload map[string]any, seen map[batchSeenKey]struct{}) error { if payload == nil { return nil } @@ -323,15 +344,15 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p } ids := make([]int64, 0, len(rawIDs)) - seen := make(map[int64]struct{}, len(rawIDs)) + seenIDs := make(map[int64]struct{}, len(rawIDs)) for _, id := range rawIDs { if id <= 0 { continue } - if _, exists := seen[id]; exists { + if _, exists := seenIDs[id]; exists { continue } - seen[id] = struct{}{} + seenIDs[id] = struct{}{} ids = append(ids, id) } if len(ids) == 0 { @@ -384,10 +405,10 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p for gid := range rebuildGroupSet { rebuildGroupIDs = append(rebuildGroupIDs, gid) } - return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change") + return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change", seen) } -func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error { +func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any, seen map[batchSeenKey]struct{}) error { if accountID == nil || *accountID <= 0 { return nil } @@ -408,7 +429,7 @@ func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accou return err } } - return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss") + return s.rebuildByGroupIDs(ctx, groupIDs, "account_miss", seen) } return err } @@ -420,18 +441,18 @@ func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accou if len(groupIDs) == 0 { groupIDs = account.GroupIDs } - return s.rebuildByAccount(ctx, account, groupIDs, "account_change") + return s.rebuildByAccount(ctx, account, groupIDs, "account_change", seen) } -func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64) error { +func (s *SchedulerSnapshotService) handleGroupEvent(ctx context.Context, groupID *int64, seen map[batchSeenKey]struct{}) error { if groupID == nil || *groupID <= 0 { return nil } groupIDs := []int64{*groupID} - return s.rebuildByGroupIDs(ctx, groupIDs, "group_change") + return s.rebuildByGroupIDs(ctx, groupIDs, "group_change", seen) } -func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string) error { +func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account *Account, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error { if account == nil { return nil } @@ -441,21 +462,21 @@ func (s *SchedulerSnapshotService) rebuildByAccount(ctx context.Context, account } var firstErr error - if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason); err != nil && firstErr == nil { + if err := s.rebuildBucketsForPlatform(ctx, account.Platform, groupIDs, reason, seen); err != nil && firstErr == nil { firstErr = err } if account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() { - if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason); err != nil && firstErr == nil { + if err := s.rebuildBucketsForPlatform(ctx, PlatformAnthropic, groupIDs, reason, seen); err != nil && firstErr == nil { firstErr = err } - if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason); err != nil && firstErr == nil { + if err := s.rebuildBucketsForPlatform(ctx, PlatformGemini, groupIDs, reason, seen); err != nil && firstErr == nil { firstErr = err } } return firstErr } -func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string) error { +func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error { groupIDs = s.normalizeGroupIDs(groupIDs) if len(groupIDs) == 0 { return nil @@ -463,19 +484,30 @@ func (s *SchedulerSnapshotService) rebuildByGroupIDs(ctx context.Context, groupI platforms := []string{PlatformAnthropic, PlatformGemini, PlatformOpenAI, PlatformAntigravity} var firstErr error for _, platform := range platforms { - if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason); err != nil && firstErr == nil { + if err := s.rebuildBucketsForPlatform(ctx, platform, groupIDs, reason, seen); err != nil && firstErr == nil { firstErr = err } } return firstErr } -func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string) error { +func (s *SchedulerSnapshotService) rebuildBucketsForPlatform(ctx context.Context, platform string, groupIDs []int64, reason string, seen map[batchSeenKey]struct{}) error { if platform == "" { return nil } var firstErr error for _, gid := range groupIDs { + // Within a single poll batch, skip (groupID, platform) pairs that were + // already rebuilt. The first rebuild loads fresh DB data for all accounts + // in the group, so subsequent rebuilds for the same group+platform within + // the same batch are redundant. + if seen != nil { + key := batchSeenKey{gid, platform} + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + } if err := s.rebuildBucket(ctx, SchedulerBucket{GroupID: gid, Platform: platform, Mode: SchedulerModeSingle}, reason); err != nil && firstErr == nil { firstErr = err } -- GitLab From 697c41a3f6d67671942ef941273095b02043dfc7 Mon Sep 17 00:00:00 2001 From: Elysia <1628615876@qq.com> Date: Thu, 16 Apr 2026 20:41:40 +0800 Subject: [PATCH 008/265] fix: create fresh context per watermark write retry attempt Each retry in the SetOutboxWatermark loop now gets its own 5s context. Previously a shared context could already be expired when the second or third attempt ran, making the retries pointless. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- backend/internal/service/scheduler_snapshot_service.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 5ead45bc..62b6993d 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -264,11 +264,11 @@ func (s *SchedulerSnapshotService) pollOutbox() { } lastID := events[len(events)-1].ID - wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer wmCancel() var wmErr error for i := range 3 { + wmCtx, wmCancel := context.WithTimeout(context.Background(), 5*time.Second) wmErr = s.cache.SetOutboxWatermark(wmCtx, lastID) + wmCancel() if wmErr == nil { break } -- GitLab From a789c8c4c70c72aa02a81736fc8ab8d9b8571f14 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 17 Apr 2026 09:37:25 +0800 Subject: [PATCH 009/265] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81opus-4.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/domain/constants.go | 2 + .../internal/pkg/antigravity/claude_types.go | 1 + .../pkg/antigravity/request_transformer.go | 12 +- backend/internal/pkg/claude/constants.go | 6 + backend/internal/service/billing_service.go | 6 + backend/internal/service/pricing_service.go | 114 +++++++++++------- frontend/src/composables/useModelWhitelist.ts | 7 +- 7 files changed, 101 insertions(+), 47 deletions(-) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 429486c3..a57f7067 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -71,6 +71,7 @@ const ( // 与前端 useModelWhitelist.ts 中的 antigravityDefaultMappings 保持一致 var DefaultAntigravityModelMapping = map[string]string{ // Claude 白名单 + "claude-opus-4-7": "claude-opus-4-7", // 官方模型 "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", // 官方模型 "claude-opus-4-6": "claude-opus-4-6-thinking", // 简称映射 "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", // 迁移旧模型 @@ -120,6 +121,7 @@ var DefaultAntigravityModelMapping = map[string]string{ // aws_region 自动调整为匹配的区域前缀(如 eu.、apac.、jp. 等) var DefaultBedrockModelMapping = map[string]string{ // Claude Opus + "claude-opus-4-7": "us.anthropic.claude-opus-4-7-v1", "claude-opus-4-6-thinking": "us.anthropic.claude-opus-4-6-v1", "claude-opus-4-6": "us.anthropic.claude-opus-4-6-v1", "claude-opus-4-5-thinking": "us.anthropic.claude-opus-4-5-20251101-v1:0", diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index ce144bb9..0b8ae5f2 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -154,6 +154,7 @@ var claudeModels = []modelDef{ {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"}, {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", CreatedAt: "2026-04-17T00:00:00Z"}, {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"}, } diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index d13a8498..b5de8166 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -582,8 +582,12 @@ func maxOutputTokensLimit(model string) int { return maxOutputTokensUpperBound } -func isAntigravityOpus46Model(model string) bool { - return strings.HasPrefix(strings.ToLower(model), "claude-opus-4-6") +// isAntigravityOpusHighTierModel 判断是否为高阶 Opus 模型(4.6+), +// 用于 adaptive thinking 时覆写为高预算。 +func isAntigravityOpusHighTierModel(model string) bool { + lower := strings.ToLower(model) + return strings.HasPrefix(lower, "claude-opus-4-6") || + strings.HasPrefix(lower, "claude-opus-4-7") } func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { @@ -605,12 +609,12 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } // - thinking.type=enabled:budget_tokens>0 用显式预算 - // - thinking.type=adaptive:仅在 Antigravity 的 Opus 4.6 上覆写为 (24576) + // - thinking.type=adaptive:在 Antigravity 的高阶 Opus(4.6+)上覆写为 (24576) budget := -1 if req.Thinking.BudgetTokens > 0 { budget = req.Thinking.BudgetTokens } - if req.Thinking.Type == "adaptive" && isAntigravityOpus46Model(req.Model) { + if req.Thinking.Type == "adaptive" && isAntigravityOpusHighTierModel(req.Model) { budget = ClaudeAdaptiveHighThinkingBudgetTokens } diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index dfca252f..21c723d2 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -83,6 +83,12 @@ var DefaultModels = []Model{ DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-06T00:00:00Z", }, + { + ID: "claude-opus-4-7", + Type: "model", + DisplayName: "Claude Opus 4.7", + CreatedAt: "2026-04-17T00:00:00Z", + }, { ID: "claude-sonnet-4-6", Type: "model", diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 763abadb..32a54cbe 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -191,6 +191,9 @@ func (s *BillingService) initFallbackPricing() { // Claude 4.6 Opus (与4.5同价) s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"] + // Claude 4.7 Opus (暂与4.6同价,待官方定价更新) + s.fallbackPrices["claude-opus-4.7"] = s.fallbackPrices["claude-opus-4.6"] + // Gemini 3.1 Pro s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{ InputPricePerToken: 2e-6, // $2 per MTok @@ -278,6 +281,9 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { // 按模型系列匹配 if strings.Contains(modelLower, "opus") { + if strings.Contains(modelLower, "4.7") || strings.Contains(modelLower, "4-7") { + return s.fallbackPrices["claude-opus-4.7"] + } if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") { return s.fallbackPrices["claude-opus-4.6"] } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 3b3f31c3..2bf48702 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -656,65 +656,95 @@ func (s *PricingService) extractBaseName(model string) string { // matchByModelFamily 基于模型系列匹配 func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { - // Claude模型系列匹配规则 - familyPatterns := map[string][]string{ - "opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"}, - "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, - "opus-4": {"claude-opus-4", "claude-3-opus"}, - "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, - "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"}, - "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"}, - "sonnet-3": {"claude-3-sonnet"}, - "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"}, - "haiku-3": {"claude-3-haiku"}, - } - - // 确定模型属于哪个系列 - var matchedFamily string - for family, patterns := range familyPatterns { - for _, pattern := range patterns { + // modelFamily 定义一个模型系列的匹配和定价查找规则。 + type modelFamily struct { + name string // 系列名称 + match []string // 用于将模型归类到此系列的模式(strings.Contains 匹配) + pricing []string // 用于在定价数据中查找价格的模式(nil 则复用 match;可包含低版本 fallback) + } + + // 按特异性降序排列:高版本号在前,避免 "claude-opus-4"(opus-4 系列) + // 因子串关系误匹配 "claude-opus-4-7"(opus-4.7 系列)。 + // 注意:原 map 实现存在 Go map 迭代随机性导致的同类 bug,此处改为有序切片修复。 + families := []modelFamily{ + {name: "opus-4.7", match: []string{"claude-opus-4-7", "claude-opus-4.7"}, pricing: []string{"claude-opus-4-7", "claude-opus-4.7", "claude-opus-4-6"}}, + {name: "opus-4.6", match: []string{"claude-opus-4-6", "claude-opus-4.6"}}, + {name: "opus-4.5", match: []string{"claude-opus-4-5", "claude-opus-4.5"}}, + {name: "opus-4", match: []string{"claude-opus-4", "claude-3-opus"}}, + {name: "sonnet-4.5", match: []string{"claude-sonnet-4-5", "claude-sonnet-4.5"}}, + {name: "sonnet-4", match: []string{"claude-sonnet-4", "claude-3-5-sonnet"}}, + {name: "sonnet-3.5", match: []string{"claude-3-5-sonnet", "claude-3.5-sonnet"}}, + {name: "sonnet-3", match: []string{"claude-3-sonnet"}}, + {name: "haiku-3.5", match: []string{"claude-3-5-haiku", "claude-3.5-haiku"}}, + {name: "haiku-3", match: []string{"claude-3-haiku"}}, + } + + // Phase 1: 按有序切片归类(最具体的系列优先匹配) + var matched *modelFamily + for i := range families { + for _, pattern := range families[i].match { if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) { - matchedFamily = family + matched = &families[i] break } } - if matchedFamily != "" { + if matched != nil { break } } - if matchedFamily == "" { - // 简单的系列匹配 - if strings.Contains(model, "opus") { - if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") { - matchedFamily = "opus-4.5" - } else { - matchedFamily = "opus-4" + // Phase 2: 二次兜底——当模型 ID 不含已知模式串时,按关键字粗分 + if matched == nil { + var fallbackName string + switch { + case strings.Contains(model, "opus"): + switch { + case strings.Contains(model, "4.7") || strings.Contains(model, "4-7"): + fallbackName = "opus-4.7" + case strings.Contains(model, "4.6") || strings.Contains(model, "4-6"): + fallbackName = "opus-4.6" + case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"): + fallbackName = "opus-4.5" + default: + fallbackName = "opus-4" } - } else if strings.Contains(model, "sonnet") { - if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") { - matchedFamily = "sonnet-4.5" - } else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") { - matchedFamily = "sonnet-3.5" - } else { - matchedFamily = "sonnet-4" + case strings.Contains(model, "sonnet"): + switch { + case strings.Contains(model, "4.5") || strings.Contains(model, "4-5"): + fallbackName = "sonnet-4.5" + case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"): + fallbackName = "sonnet-3.5" + default: + fallbackName = "sonnet-4" } - } else if strings.Contains(model, "haiku") { - if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") { - matchedFamily = "haiku-3.5" - } else { - matchedFamily = "haiku-3" + case strings.Contains(model, "haiku"): + switch { + case strings.Contains(model, "3-5") || strings.Contains(model, "3.5"): + fallbackName = "haiku-3.5" + default: + fallbackName = "haiku-3" + } + } + if fallbackName != "" { + for i := range families { + if families[i].name == fallbackName { + matched = &families[i] + break + } } } } - if matchedFamily == "" { + if matched == nil { return nil } - // 在价格数据中查找该系列的模型 - patterns := familyPatterns[matchedFamily] - for _, pattern := range patterns { + // Phase 3: 在定价数据中查找该系列的价格 + lookups := matched.pricing + if lookups == nil { + lookups = matched.match + } + for _, pattern := range lookups { for key, pricing := range s.pricingData { keyLower := strings.ToLower(key) if strings.Contains(keyLower, pattern) { diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index d16ed977..a282ae7d 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -43,6 +43,7 @@ export const claudeModels = [ 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001', 'claude-opus-4-5-20251101', 'claude-opus-4-6', + 'claude-opus-4-7', 'claude-sonnet-4-6', 'claude-2.1', 'claude-2.0', 'claude-instant-1.2' ] @@ -66,6 +67,7 @@ const antigravityModels = [ // Claude 4.5+ 系列 'claude-opus-4-6', 'claude-opus-4-6-thinking', + 'claude-opus-4-7', 'claude-opus-4-5-thinking', 'claude-sonnet-4-6', 'claude-sonnet-4-5', @@ -250,6 +252,7 @@ const anthropicPresetMappings = [ { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'Opus 4.5', from: 'claude-opus-4-5-20251101', to: 'claude-opus-4-5-20251101', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, + { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Haiku 3.5', from: 'claude-3-5-haiku-20241022', to: 'claude-3-5-haiku-20241022', color: 'bg-green-100 text-green-700 hover:bg-green-200 dark:bg-green-900/30 dark:text-green-400' }, { label: 'Haiku 4.5', from: 'claude-haiku-4-5-20251001', to: 'claude-haiku-4-5-20251001', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, { label: 'Opus->Sonnet', from: 'claude-opus-4-6', to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' } @@ -309,12 +312,14 @@ const antigravityPresetMappings = [ { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'claude-sonnet-4-5', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, - { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' } + { label: 'Opus 4.6-thinking', from: 'claude-opus-4-6-thinking', to: 'claude-opus-4-6-thinking', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, + { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'claude-opus-4-7', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' } ] // Bedrock 预设映射(与后端 DefaultBedrockModelMapping 保持一致) const bedrockPresetMappings = [ { label: 'Opus 4.6', from: 'claude-opus-4-6', to: 'us.anthropic.claude-opus-4-6-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, + { label: 'Opus 4.7', from: 'claude-opus-4-7', to: 'us.anthropic.claude-opus-4-7-v1', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, { label: 'Sonnet 4.6', from: 'claude-sonnet-4-6', to: 'us.anthropic.claude-sonnet-4-6', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Opus 4.5', from: 'claude-opus-4-5-thinking', to: 'us.anthropic.claude-opus-4-5-20251101-v1:0', color: 'bg-pink-100 text-pink-700 hover:bg-pink-200 dark:bg-pink-900/30 dark:text-pink-400' }, { label: 'Sonnet 4.5', from: 'claude-sonnet-4-5', to: 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, -- GitLab From 5d586a9f3a18d724d7883fa636328400c302e261 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 17 Apr 2026 10:17:50 +0800 Subject: [PATCH 010/265] =?UTF-8?q?fix:=20=E4=B8=8A=E6=B8=B8=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=20KYC=20=E8=BA=AB=E4=BB=BD=E9=AA=8C=E8=AF=81=E8=A6=81?= =?UTF-8?q?=E6=B1=82=E6=97=B6=E5=81=9C=E6=AD=A2=E8=B4=A6=E5=8F=B7=E8=B0=83?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/ratelimit_service.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 4d8009b7..53581574 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -152,6 +152,11 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc msg := "Credit balance exhausted (400): " + upstreamMsg s.handleAuthError(ctx, account, msg) shouldDisable = true + } else if strings.Contains(strings.ToLower(upstreamMsg), "identity verification is required") { + // KYC 身份验证要求 → 永久禁用,账号需完成身份验证后才能恢复 + msg := "Identity verification required (400): " + upstreamMsg + s.handleAuthError(ctx, account, msg) + shouldDisable = true } // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: -- GitLab From 6cfdf4ec051fcac04456367b648622bce175ad7d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 17 Apr 2026 02:51:18 +0000 Subject: [PATCH 011/265] chore: sync VERSION to 0.1.114 [skip ci] --- backend/cmd/server/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index c21e67e6..c29f5f75 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.113 +0.1.114 -- GitLab From 11b97147d5dfba235f5398e94376a8e8dffb6e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E9=9D=9E=E4=BB=99?= <1506088208@qq.com> Date: Fri, 17 Apr 2026 14:59:19 +0800 Subject: [PATCH 012/265] fix: update CI rules to allow tags starting with 'v', 's', or 't' --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a1b1c7e9..f493a3d3 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -47,7 +47,7 @@ sub2api: VERSION: "${CI_COMMIT_TAG:-}" COMMIT: "${CI_COMMIT_SHA}" rules: - - if: $CI_COMMIT_TAG =~ /^sub2api\/v/ + - if: $CI_COMMIT_TAG =~ /^sub2api\/[vst]/ when: always - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH when: always -- GitLab From fd0c9a130530eeab28c212cec601a3ef5ba1f843 Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:00:29 +0800 Subject: [PATCH 013/265] fix(payment): store provider config as plaintext JSON with legacy ciphertext fallback Without TOTP_ENCRYPTION_KEY, saved payment configs were lost on restart because the AES round-trip failed silently. Write new records as plaintext JSON; read path tries JSON first, falls back to legacy AES decrypt when a key is present, and treats unreadable values as empty so admins can re-enter them via the UI. --- backend/internal/payment/crypto.go | 11 ++- backend/internal/payment/load_balancer.go | 30 ++++-- .../internal/payment/load_balancer_test.go | 97 +++++++++++++++++++ .../service/payment_config_providers.go | 38 +++++--- 4 files changed, 151 insertions(+), 25 deletions(-) diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index e39e957f..5467e50b 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -10,12 +10,15 @@ import ( "strings" ) +// AES256KeySize is the required key length (in bytes) for AES-256-GCM. +const AES256KeySize = 32 + // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. func Encrypt(plaintext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } block, err := aes.NewCipher(key) @@ -52,8 +55,8 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. func Decrypt(ciphertext string, key []byte) (string, error) { - if len(key) != 32 { - return "", fmt.Errorf("encryption key must be 32 bytes, got %d", len(key)) + if len(key) != AES256KeySize { + return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) } parts := strings.SplitN(ciphertext, ":", 3) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index f0353173..52a1b011 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -261,6 +261,9 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns if err != nil { return nil, fmt.Errorf("decrypt instance %d config: %w", selected.ID, err) } + if config == nil { + config = map[string]string{} + } if selected.PaymentMode != "" { config["paymentMode"] = selected.PaymentMode @@ -275,16 +278,29 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns }, nil } -func (lb *DefaultLoadBalancer) decryptConfig(encrypted string) (map[string]string, error) { - plaintext, err := Decrypt(encrypted, lb.encryptionKey) - if err != nil { - return nil, err +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext. +// Unreadable values (legacy ciphertext without a valid key, or malformed data) +// are treated as empty so the service keeps running while the admin re-enters +// the config via the UI. +func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { + return nil, nil } var config map[string]string - if err := json.Unmarshal([]byte(plaintext), &config); err != nil { - return nil, fmt.Errorf("unmarshal config: %w", err) + if err := json.Unmarshal([]byte(stored), &config); err == nil { + return config, nil + } + if len(lb.encryptionKey) == AES256KeySize { + if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &config); err == nil { + return config, nil + } + } } - return config, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } // GetInstanceDailyAmount returns the total completed order amount for an instance today. diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 04b3c25b..2bf4f6ac 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -452,6 +452,103 @@ func TestStartOfDay(t *testing.T) { } } +func TestDecryptConfig_PlaintextAndLegacyCompat(t *testing.T) { + t.Parallel() + + key := make([]byte, AES256KeySize) + for i := range key { + key[i] = byte(i + 1) + } + wrongKey := make([]byte, AES256KeySize) + for i := range wrongKey { + wrongKey[i] = byte(0xFF - i) + } + + plaintextJSON := `{"appId":"app-123","secret":"sec-xyz"}` + + legacyEncrypted, err := Encrypt(plaintextJSON, key) + if err != nil { + t.Fatalf("seed Encrypt: %v", err) + } + + tests := []struct { + name string + stored string + key []byte + want map[string]string + }{ + { + name: "empty stored returns nil map", + stored: "", + key: key, + want: nil, + }, + { + name: "plaintext JSON parses directly", + stored: plaintextJSON, + key: nil, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "plaintext JSON works even with key present", + stored: plaintextJSON, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with correct key decrypts", + stored: legacyEncrypted, + key: key, + want: map[string]string{"appId": "app-123", "secret": "sec-xyz"}, + }, + { + name: "legacy ciphertext with no key treated as empty", + stored: legacyEncrypted, + key: nil, + want: nil, + }, + { + name: "legacy ciphertext with wrong key treated as empty", + stored: legacyEncrypted, + key: wrongKey, + want: nil, + }, + { + name: "garbage data treated as empty", + stored: "not-json-and-not-ciphertext", + key: key, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + lb := NewDefaultLoadBalancer(nil, tt.key) + got, err := lb.decryptConfig(tt.stored) + if err != nil { + t.Fatalf("decryptConfig unexpected error: %v", err) + } + if !stringMapEqual(got, tt.want) { + t.Fatalf("decryptConfig = %v, want %v", got, tt.want) + } + }) + } +} + +// stringMapEqual compares two map[string]string values; nil and empty are equal. +func stringMapEqual(a, b map[string]string) bool { + if len(a) != len(b) { + return false + } + for k, v := range a { + if bv, ok := b[k]; !ok || bv != v { + return false + } + } + return true +} + // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 3c406b45..59337ad6 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "strconv" "strings" @@ -290,19 +291,29 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return existing, nil } -func (s *PaymentConfigService) decryptConfig(encrypted string) (map[string]string, error) { - if encrypted == "" { +// decryptConfig parses a stored provider config. +// New records are plaintext JSON; legacy records are AES-256-GCM ciphertext +// ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including +// legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, +// letting the admin re-enter the config via the UI to complete the migration. +func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { + if stored == "" { return nil, nil } - decrypted, err := payment.Decrypt(encrypted, s.encryptionKey) - if err != nil { - return nil, fmt.Errorf("decrypt config: %w", err) + var cfg map[string]string + if err := json.Unmarshal([]byte(stored), &cfg); err == nil { + return cfg, nil } - var raw map[string]string - if err := json.Unmarshal([]byte(decrypted), &raw); err != nil { - return nil, fmt.Errorf("unmarshal decrypted config: %w", err) + if len(s.encryptionKey) == payment.AES256KeySize { + if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { + if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { + return cfg, nil + } + } } - return raw, nil + slog.Warn("payment provider config unreadable, treating as empty for re-entry", + "stored_len", len(stored)) + return nil, nil } func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id int64) error { @@ -317,14 +328,13 @@ func (s *PaymentConfigService) DeleteProviderInstance(ctx context.Context, id in return s.entClient.PaymentProviderInstance.DeleteOneID(id).Exec(ctx) } +// encryptConfig serialises a provider config for storage. +// New records are written as plaintext JSON; the historical AES-GCM wrapping +// has been dropped but decryptConfig still accepts old ciphertext during migration. func (s *PaymentConfigService) encryptConfig(cfg map[string]string) (string, error) { data, err := json.Marshal(cfg) if err != nil { return "", fmt.Errorf("marshal config: %w", err) } - enc, err := payment.Encrypt(string(data), s.encryptionKey) - if err != nil { - return "", fmt.Errorf("encrypt config: %w", err) - } - return enc, nil + return string(data), nil } -- GitLab From 44cdef7934168f167c5d433e5947aea3ac5a279d Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:00:45 +0800 Subject: [PATCH 014/265] fix(usage): subscription billing honours group rate multiplier MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Subscription-mode billing was consuming quota at TotalCost (raw) instead of ActualCost (TotalCost * RateMultiplier), so per-group rate multipliers — including free subscriptions (multiplier = 0) — were silently ignored. Switch the three subscription cost writes in buildUsageBillingCommand, finalizePostUsageBilling, and the legacy postUsageBilling fallback to ActualCost, and add a table-driven test covering 2x / 0.5x / free multipliers plus a balance-mode regression check. --- backend/internal/service/gateway_service.go | 16 ++-- ...teway_service_subscription_billing_test.go | 85 +++++++++++++++++++ 2 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 backend/internal/service/gateway_service_subscription_billing_test.go diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4b4fc0bf..07a9e41c 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7317,8 +7317,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill cost := p.Cost if p.IsSubscriptionBill { - if cost.TotalCost > 0 { - if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { + // Subscription usage tracked by ActualCost so group rate multiplier + // consumes the quota at the expected speed. + if cost.ActualCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) } } @@ -7417,9 +7419,13 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage } } + // Record subscription / balance cost using ActualCost so the group (and any + // user-specific) rate multiplier consumes subscription quota at the expected + // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards + // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0). if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { cmd.SubscriptionID = &p.Subscription.ID - cmd.SubscriptionCost = p.Cost.TotalCost + cmd.SubscriptionCost = p.Cost.ActualCost } else if p.Cost.ActualCost > 0 { cmd.BalanceCost = p.Cost.ActualCost } @@ -7478,8 +7484,8 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, resu } if p.IsSubscriptionBill { - if p.Cost.TotalCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.TotalCost) + if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { + deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost) } } else if p.Cost.ActualCost > 0 && p.User != nil { deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) diff --git a/backend/internal/service/gateway_service_subscription_billing_test.go b/backend/internal/service/gateway_service_subscription_billing_test.go new file mode 100644 index 00000000..42a81035 --- /dev/null +++ b/backend/internal/service/gateway_service_subscription_billing_test.go @@ -0,0 +1,85 @@ +//go:build unit + +package service + +import ( + "testing" +) + +// TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier locks in the fix +// that subscription-mode billing honours the group (and any user-specific) rate +// multiplier — i.e. cmd.SubscriptionCost tracks ActualCost (= TotalCost * +// RateMultiplier), not raw TotalCost. +func TestBuildUsageBillingCommand_SubscriptionAppliesRateMultiplier(t *testing.T) { + t.Parallel() + + groupID := int64(7) + subID := int64(42) + + tests := []struct { + name string + totalCost float64 + actualCost float64 + isSubscription bool + wantSub float64 + wantBalance float64 + }{ + { + name: "subscription with 2x multiplier consumes 2x quota", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: true, + wantSub: 2.0, + wantBalance: 0, + }, + { + name: "subscription with 0.5x multiplier consumes 0.5x quota", + totalCost: 1.0, + actualCost: 0.5, + isSubscription: true, + wantSub: 0.5, + wantBalance: 0, + }, + { + name: "free subscription (multiplier 0) consumes no quota", + totalCost: 1.0, + actualCost: 0, + isSubscription: true, + wantSub: 0, + wantBalance: 0, + }, + { + name: "balance billing keeps using ActualCost (regression)", + totalCost: 1.0, + actualCost: 2.0, + isSubscription: false, + wantSub: 0, + wantBalance: 2.0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := &postUsageBillingParams{ + Cost: &CostBreakdown{TotalCost: tt.totalCost, ActualCost: tt.actualCost}, + User: &User{ID: 1}, + APIKey: &APIKey{ID: 2, GroupID: &groupID}, + Account: &Account{ID: 3}, + Subscription: &UserSubscription{ID: subID}, + IsSubscriptionBill: tt.isSubscription, + } + + cmd := buildUsageBillingCommand("req-1", nil, p) + if cmd == nil { + t.Fatal("buildUsageBillingCommand returned nil") + } + if cmd.SubscriptionCost != tt.wantSub { + t.Errorf("SubscriptionCost = %v, want %v", cmd.SubscriptionCost, tt.wantSub) + } + if cmd.BalanceCost != tt.wantBalance { + t.Errorf("BalanceCost = %v, want %v", cmd.BalanceCost, tt.wantBalance) + } + }) + } +} -- GitLab From 948d8e6d024412cc2efda34ec41540fddb4bfb0b Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 17:01:01 +0800 Subject: [PATCH 015/265] fix(admin): prevent browser password manager from autofilling account API key Chrome's password manager matched the apikey-type account's Base URL + API Key inputs as a login form and autofilled the last saved password by domain, so editing a Gemini account could overwrite its apikey with a Claude key that shared the same Base URL. Add autocomplete="new-password" plus data-*-ignore attributes for 1Password / LastPass / Bitwarden to opt the field out of every major password manager's autofill. --- frontend/src/components/account/EditAccountModal.vue | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 1da32e2c..59ca0b9c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -52,6 +52,10 @@ v-model="editApiKey" type="password" class="input font-mono" + autocomplete="new-password" + data-1p-ignore + data-lpignore="true" + data-bwignore="true" :placeholder=" account.platform === 'openai' ? 'sk-proj-...' -- GitLab From df57d2776b1a74e37470970f1bcf8942ad810e1d Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 18:32:12 +0800 Subject: [PATCH 016/265] fix(billing): reject rate_multiplier <= 0 on save; clamp negatives to 0 in compute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 分组倍率和用户专属倍率在保存时没有校验,0 会触发计费层的 `<=0 → 1.0` 防御条款,结果订阅/余额分组按标准价扣费;完全是沉默地绕过了业务规则。 - 保存校验(admin_service):CreateGroup / UpdateGroup / BatchSetGroupRateMultipliers / UpdateUser.SyncUserGroupRates 全部要求 > 0 - 计算层(billing_service):三处 `<=0 → 1.0` 改为 `<0 → 0`;负数按 0 结算, 避免配置异常被静默按 1x 收费 - 前端:分组倍率 / 用户专属倍率输入 min 统一到 0.001 - 删除未使用的 IsFreeSubscription 方法 测试:新增 billing_service_rate_multiplier_test.go 端到端验证;更新原有锁定 旧 `<=0 → 1.0` 行为的测试。 --- backend/internal/service/admin_service.go | 21 +++++++ .../service/admin_service_group_test.go | 6 ++ backend/internal/service/billing_service.go | 16 ++--- .../service/billing_service_image_test.go | 5 +- .../billing_service_rate_multiplier_test.go | 63 +++++++++++++++++++ .../internal/service/billing_service_test.go | 28 --------- .../service/billing_service_unified_test.go | 40 ++++-------- backend/internal/service/group.go | 4 -- .../openai_gateway_record_usage_test.go | 2 +- .../admin/group/GroupRateMultipliersModal.vue | 2 +- .../admin/user/UserAllowedGroupsModal.vue | 4 +- 11 files changed, 119 insertions(+), 72 deletions(-) create mode 100644 backend/internal/service/billing_service_rate_multiplier_test.go diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c..701f3659 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -586,6 +586,15 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + user, err := s.userRepo.GetByID(ctx, id) if err != nil { return nil, err @@ -811,6 +820,10 @@ func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, erro } func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + platform := input.Platform if platform == "" { platform = PlatformAnthropic @@ -1050,6 +1063,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.Platform = input.Platform } if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } group.RateMultiplier = *input.RateMultiplier } if input.IsExclusive != nil { @@ -1286,6 +1302,11 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro if s.userGroupRateRepo == nil { return nil } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index a4c6d0ca..41d2c26a 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -621,6 +621,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatfo _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformOpenAI, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -641,6 +642,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *t _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeSubscription, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -695,6 +697,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t * _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -713,6 +716,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -733,6 +737,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAntigravity, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &fallbackID, }) @@ -750,6 +755,7 @@ func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing. group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ Name: "g1", Platform: PlatformAnthropic, + RateMultiplier: 1.0, SubscriptionType: SubscriptionTypeStandard, FallbackGroupIDOnInvalidRequest: &zero, }) diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 32a54cbe..c9f32b3b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -448,8 +448,9 @@ func (s *BillingService) CalculateCostUnified(input CostInput) (*CostBreakdown, }) } - if input.RateMultiplier <= 0 { - input.RateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏(缓存/迁移残留),按 0 处理避免按 1x 误扣。 + if input.RateMultiplier < 0 { + input.RateMultiplier = 0 } var breakdown *CostBreakdown @@ -493,8 +494,9 @@ func (s *BillingService) computeTokenBreakdown( rateMultiplier float64, serviceTier string, applyLongCtx bool, ) *CostBreakdown { - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 保存时强制 > 0;若仍有负数泄漏,按 0 处理避免按 1x 误扣。 + if rateMultiplier < 0 { + rateMultiplier = 0 } inputPrice := pricing.InputPricePerToken @@ -831,9 +833,9 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag // 计算总费用 totalCost := unitPrice * float64(imageCount) - // 应用倍率 - if rateMultiplier <= 0 { - rateMultiplier = 1.0 + // 应用倍率(保存时强制 > 0;负数按 0 处理避免按 1x 误扣) + if rateMultiplier < 0 { + rateMultiplier = 0 } actualCost := totalCost * rateMultiplier diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index fa90f6bb..8d3ca987 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -90,13 +90,14 @@ func TestCalculateImageCost_NegativeCount(t *testing.T) { require.Equal(t, 0.0, cost.ActualCost) } -// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0 +// TestCalculateImageCost_ZeroRateMultiplier 锁定新行为:倍率 0 直接按 0 计费 +// (保存时已强制 > 0;若仍有 0 泄漏到计费层,零消耗比历史的 1.0 更安全)。 func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) { svc := &BillingService{} cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0) require.InDelta(t, 0.201, cost.TotalCost, 0.0001) - require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理 + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } // TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格 diff --git a/backend/internal/service/billing_service_rate_multiplier_test.go b/backend/internal/service/billing_service_rate_multiplier_test.go new file mode 100644 index 00000000..83788196 --- /dev/null +++ b/backend/internal/service/billing_service_rate_multiplier_test.go @@ -0,0 +1,63 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestCalculateCost_RateMultiplier_NegativeClampedToZero 锁定负数倍率被 +// 钳制为 0(而非历史上的 1.0),避免配置异常导致静默按标准价扣费。 +func TestCalculateCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 // ActualCost / TotalCost + }{ + {"negative clamped to 0", -1.5, 0}, + {"zero passes through as 0 (defense in depth)", 0, 0}, + {"positive 2x applied", 2.0, 2.0}, + {"positive 0.5x applied", 0.5, 0.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, tt.multiplier) + require.NoError(t, err) + require.Greater(t, cost.TotalCost, 0.0, "TotalCost should be non-zero") + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} + +// TestCalculateImageCost_RateMultiplier_NegativeClampedToZero 图片按次计费路径 +// 同样遵循"负数 → 0"语义。 +func TestCalculateImageCost_RateMultiplier_NegativeClampedToZero(t *testing.T) { + svc := newTestBillingService() + price := 0.04 + cfg := &ImagePriceConfig{Price1K: &price} + + tests := []struct { + name string + multiplier float64 + wantRatio float64 + }{ + {"negative clamped to 0", -0.5, 0}, + {"zero passes through", 0, 0}, + {"positive 3x applied", 3.0, 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cost := svc.CalculateImageCost("imagen-3", "1K", 2, cfg, tt.multiplier) + require.NotNil(t, cost) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, tt.wantRatio*cost.TotalCost, cost.ActualCost, 1e-9) + }) + } +} diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index 2cf134e2..fc8361c7 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -71,34 +71,6 @@ func TestCalculateCost_RateMultiplier(t *testing.T) { require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) } -func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) -} - -func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { - svc := newTestBillingService() - - tokens := UsageTokens{InputTokens: 1000} - - costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) - require.NoError(t, err) - - costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) -} - func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { svc := newTestBillingService() diff --git a/backend/internal/service/billing_service_unified_test.go b/backend/internal/service/billing_service_unified_test.go index 694c3384..e6a92d1a 100644 --- a/backend/internal/service/billing_service_unified_test.go +++ b/backend/internal/service/billing_service_unified_test.go @@ -147,40 +147,35 @@ func TestCalculateCostUnified_ImageMode(t *testing.T) { require.Equal(t, string(BillingModeImage), cost.BillingMode) } -func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_RateMultiplierZeroProducesZero 锁定新行为: +// 保存时强制 > 0;若 0 仍泄漏到计费层,按 0 计费(而非历史上的 1.0)。 +func TestCalculateCostUnified_RateMultiplierZeroProducesZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} - costZero, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 0, // should default to 1.0 - Resolver: resolver, - }) - require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, - RateMultiplier: 1.0, + RateMultiplier: 0, Resolver: resolver, }) require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } -func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) { +// TestCalculateCostUnified_NegativeRateMultiplierClampedToZero 锁定新行为: +// 负数倍率按 0 计费,避免历史的 <=0 → 1.0 把配置异常静默按标准价扣费。 +func TestCalculateCostUnified_NegativeRateMultiplierClampedToZero(t *testing.T) { bs := newTestBillingService() resolver := NewModelPricingResolver(nil, bs) tokens := UsageTokens{InputTokens: 1000} - costNeg, err := bs.CalculateCostUnified(CostInput{ + cost, err := bs.CalculateCostUnified(CostInput{ Ctx: context.Background(), Model: "claude-sonnet-4", Tokens: tokens, @@ -188,17 +183,8 @@ func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) Resolver: resolver, }) require.NoError(t, err) - - costOne, err := bs.CalculateCostUnified(CostInput{ - Ctx: context.Background(), - Model: "claude-sonnet-4", - Tokens: tokens, - RateMultiplier: 1.0, - Resolver: resolver, - }) - require.NoError(t, err) - - require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) + require.Greater(t, cost.TotalCost, 0.0) + require.InDelta(t, 0.0, cost.ActualCost, 1e-10) } func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 12262613..64434ae1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -76,10 +76,6 @@ func (g *Group) IsSubscriptionType() bool { return g.SubscriptionType == SubscriptionTypeSubscription } -func (g *Group) IsFreeSubscription() bool { - return g.IsSubscriptionType() && g.RateMultiplier == 0 -} - func (g *Group) HasDailyLimit() bool { return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0 } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index e6fa94aa..6fa8a5bd 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1031,7 +1031,7 @@ func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSetsSubscriptionFiel Model: "gpt-5.1", Duration: time.Second, }, - APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription}}, + APIKey: &APIKey{ID: 100, GroupID: i64p(88), Group: &Group{ID: 88, SubscriptionType: SubscriptionTypeSubscription, RateMultiplier: 1.0}}, User: &User{ID: 200}, Account: &Account{ID: 300}, Subscription: subscription, diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index bf79bea2..41b2e63c 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -166,7 +166,7 @@ Date: Fri, 17 Apr 2026 22:07:15 +0800 Subject: [PATCH 017/265] feat(gateway): raise upstream response read limit 8MB -> 128MB (configurable) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 图片生成 API 返回的 base64 内联图响应经常超过 8MB 单次读取上限,被 ReadUpstreamResponseBody 拦截成 502 upstream_error。 单张 4K PNG base64 最坏约 67MB,多张候选图或 imageSize=4K 的 image_generation 一次请求能轻松到 30MB+。把默认上限提到 128MB 能覆盖 2-3 张 4K 图,相对 请求体上限 256MB 仍有缓冲;同时抽出 config.DefaultUpstreamResponseReadMaxBytes 共享常量,viper 默认值和 service 层兜底共用,消除两处同步魔法数字。 仍可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 --- backend/internal/config/config.go | 7 ++++++- backend/internal/service/upstream_response_limit.go | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e58..15592905 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -52,6 +52,11 @@ const ( ConnectionPoolIsolationAccountProxy = "account_proxy" ) +// DefaultUpstreamResponseReadMaxBytes 上游非流式响应体的默认读取上限。 +// 128 MB 可容纳 2-3 张 4K PNG(base64 膨胀 33%,单张 4K PNG 最坏约 67MB base64)。 +// 可通过 gateway.upstream_response_read_max_bytes 配置项覆盖。 +const DefaultUpstreamResponseReadMaxBytes int64 = 128 * 1024 * 1024 + type Config struct { Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` @@ -1407,7 +1412,7 @@ func setDefaults() { viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", DefaultUpstreamResponseReadMaxBytes) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) diff --git a/backend/internal/service/upstream_response_limit.go b/backend/internal/service/upstream_response_limit.go index a0444d52..ddf0e818 100644 --- a/backend/internal/service/upstream_response_limit.go +++ b/backend/internal/service/upstream_response_limit.go @@ -12,7 +12,9 @@ import ( var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large") -const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024 +// defaultUpstreamResponseReadMaxBytes 源自 config.DefaultUpstreamResponseReadMaxBytes, +// 仅在 cfg 为 nil 时作为兜底(测试或极端场景)。 +const defaultUpstreamResponseReadMaxBytes = config.DefaultUpstreamResponseReadMaxBytes func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 { if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 { -- GitLab From 61a008f7e4e0e019f00e63409389aa5aa1058b0a Mon Sep 17 00:00:00 2001 From: erio Date: Fri, 17 Apr 2026 23:05:58 +0800 Subject: [PATCH 018/265] chore(payment): mark legacy AES ciphertext fallback as deprecated MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 明文 JSON 已经是新写入的默认格式;保留 AES 密文读取仅为兼容迁移期间的旧 记录,一旦所有部署通过管理后台重存过一次即可删除。标记为 deprecated 并加 TODO,几个版本后统一清理掉:payment.Encrypt / payment.Decrypt、两处 decryptConfig 的 AES 分支、PaymentConfigService.encryptionKey 和 DefaultLoadBalancer.encryptionKey 字段。 --- backend/internal/payment/crypto.go | 10 ++++++++++ backend/internal/payment/load_balancer.go | 7 +++++++ backend/internal/service/payment_config_providers.go | 6 ++++++ 3 files changed, 23 insertions(+) diff --git a/backend/internal/payment/crypto.go b/backend/internal/payment/crypto.go index 5467e50b..0581469d 100644 --- a/backend/internal/payment/crypto.go +++ b/backend/internal/payment/crypto.go @@ -16,6 +16,11 @@ const AES256KeySize = 32 // Encrypt encrypts plaintext using AES-256-GCM with the given 32-byte key. // The output format is "iv:authTag:ciphertext" where each component is base64-encoded, // matching the Node.js crypto.ts format for cross-compatibility. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function is kept only for seeding legacy ciphertext in tests and for +// the transitional Decrypt fallback. Scheduled for removal after all live +// deployments complete migration by re-saving their configs. func Encrypt(plaintext string, key []byte) (string, error) { if len(key) != AES256KeySize { return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) @@ -54,6 +59,11 @@ func Encrypt(plaintext string, key []byte) (string, error) { // Decrypt decrypts a ciphertext string produced by Encrypt. // The input format is "iv:authTag:ciphertext" where each component is base64-encoded. +// +// Deprecated: payment provider configs are now stored as plaintext JSON. +// This function remains only as a read-path fallback for pre-migration +// ciphertext records. Scheduled for removal once all deployments re-save +// their provider configs through the admin UI. func Decrypt(ciphertext string, key []byte) (string, error) { if len(key) != AES256KeySize { return "", fmt.Errorf("encryption key must be %d bytes, got %d", AES256KeySize, len(key)) diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index 52a1b011..ec244cd6 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -283,6 +283,11 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns // Unreadable values (legacy ciphertext without a valid key, or malformed data) // are treated as empty so the service keeps running while the admin re-enters // the config via the UI. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch below is a +// transitional compatibility shim for pre-plaintext records. Remove it (and +// the encryptionKey field + the Decrypt import) after a few releases once all +// live deployments have re-saved their provider configs through the UI. func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, error) { if stored == "" { return nil, nil @@ -291,7 +296,9 @@ func (lb *DefaultLoadBalancer) decryptConfig(stored string) (map[string]string, if err := json.Unmarshal([]byte(stored), &config); err == nil { return config, nil } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. if len(lb.encryptionKey) == AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal if plaintext, err := Decrypt(stored, lb.encryptionKey); err == nil { if err := json.Unmarshal([]byte(plaintext), &config); err == nil { return config, nil diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 59337ad6..8e470525 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -296,6 +296,10 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon // ("iv:authTag:ciphertext"). Values that cannot be parsed as either — including // legacy ciphertext with no/invalid TOTP_ENCRYPTION_KEY — are treated as empty, // letting the admin re-enter the config via the UI to complete the migration. +// +// TODO(deprecated-legacy-ciphertext): The AES fallback branch is a transitional +// shim for pre-plaintext records. Remove it (and the encryptionKey field) after +// a few releases once all live deployments have re-saved their provider configs. func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, error) { if stored == "" { return nil, nil @@ -304,7 +308,9 @@ func (s *PaymentConfigService) decryptConfig(stored string) (map[string]string, if err := json.Unmarshal([]byte(stored), &cfg); err == nil { return cfg, nil } + // Deprecated: legacy AES-256-GCM ciphertext fallback — scheduled for removal. if len(s.encryptionKey) == payment.AES256KeySize { + //nolint:staticcheck // SA1019: intentional legacy fallback, scheduled for removal if plaintext, err := payment.Decrypt(stored, s.encryptionKey); err == nil { if err := json.Unmarshal([]byte(plaintext), &cfg); err == nil { return cfg, nil -- GitLab From 37123cef8fc14ca3d4b336a709167075ff8ea0df Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 18 Apr 2026 14:42:55 +0800 Subject: [PATCH 019/265] docs(payment): add Kyren Topup as international EasyPay provider option MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restructure the EasyPay recommendation block to present two options side by side so users can pick by funding channel and settlement currency: - Domestic / CNY — ZPay: official Alipay/WeChat API with 1.6% fee and T+1 automatic settlement (existing recommendation, expanded with fee and settlement details). - International / USDT or USD — Kyren Topup (https://kyren.top): global payment stack supporting WeChat Pay and Alipay with local-currency checkout, USD settlement, and USDT/USD withdrawal. Fees: WeChat 2%, Alipay 2.5%, withdrawal 0.1% ($40 min / $150 max). Fills the gap for users who cannot use domestic Chinese channels or tolerate Stripe's 6%+ fees. Both recommendations share a single disclaimer at the end. The Chinese heading uses "易支付" while the English one keeps "EasyPay". --- docs/PAYMENT.md | 7 ++++++- docs/PAYMENT_CN.md | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md index b66a791c..2735aea3 100644 --- a/docs/PAYMENT.md +++ b/docs/PAYMENT.md @@ -28,7 +28,12 @@ Sub2API has a built-in payment system that enables user self-service top-up with > Alipay/WeChat Pay direct and EasyPay can coexist. Direct channels connect to payment APIs directly with lower fees; EasyPay aggregates through third-party platforms with easier setup. -> **EasyPay Recommendation**: [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`) is recommended as an EasyPay provider (link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it). ZPay supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. +> **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need: +> +> - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it. +> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it. +> +> Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. --- diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md index 9d96557f..474700e5 100644 --- a/docs/PAYMENT_CN.md +++ b/docs/PAYMENT_CN.md @@ -28,7 +28,12 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 > 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。 -> **EasyPay 推荐**:个人推荐 [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`)作为 EasyPay 服务商(链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉)。ZPay 支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 +> **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择: +> +> - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。 +> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。 +> +> 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 --- -- GitLab From 6ae1cc8f3f893ea1c6896ee045c8e26ff4b2f16d Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 18 Apr 2026 14:45:25 +0800 Subject: [PATCH 020/265] =?UTF-8?q?docs:=20use=20=E6=98=93=E6=94=AF?= =?UTF-8?q?=E4=BB=98=20in=20Chinese=20coexistence=20note?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/PAYMENT_CN.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md index 474700e5..11325a80 100644 --- a/docs/PAYMENT_CN.md +++ b/docs/PAYMENT_CN.md @@ -26,7 +26,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 | **微信官方** | Native 扫码支付、H5 支付 | 直接对接微信支付 APIv3,移动端优先 H5 | | **Stripe** | 银行卡、支付宝、微信支付、Link 等 | 国际支付,支持多币种 | -> 支付宝官方 / 微信官方与 EasyPay 可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;EasyPay 通过第三方平台聚合,接入门槛更低。 +> 支付宝官方 / 微信官方与易支付可以共存。官方渠道直接对接 API,资金直达商户账户,手续费更低;易支付通过第三方平台聚合,接入门槛更低。 > **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择: > -- GitLab From 0c538a584fbf48414fe3c11b9a3c4ddb2dad96d5 Mon Sep 17 00:00:00 2001 From: erio Date: Sat, 18 Apr 2026 14:48:42 +0800 Subject: [PATCH 021/265] docs: note Kyren Topup $200 account fee waived via referral link --- docs/PAYMENT.md | 2 +- docs/PAYMENT_CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/PAYMENT.md b/docs/PAYMENT.md index 2735aea3..755b313a 100644 --- a/docs/PAYMENT.md +++ b/docs/PAYMENT.md @@ -31,7 +31,7 @@ Sub2API has a built-in payment system that enables user self-service top-up with > **EasyPay Provider Recommendations**: Both options below are third-party aggregators compatible with the EasyPay protocol. Pick based on the funding channel and settlement currency you need: > > - **Domestic channel / CNY settlement** — [ZPay](https://z-pay.cn/?uid=23808) (`https://z-pay.cn/?uid=23808`): direct integration with official Alipay / WeChat Pay APIs, fee **1.6%**; funds go straight to the merchant account with **T+1 automatic settlement**. Supports **individual users** (no business license required) with up to 10,000 CNY daily transactions; business-licensed accounts have no limit. Link contains the referral code of [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) original author [@touwaeriol](https://github.com/touwaeriol) — feel free to remove it. -> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Link contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code — feel free to remove it. +> - **International channel / USDT or USD settlement** — [Kyren Topup](https://kyren.top/?code=SUB2API) (`https://kyren.top/?code=SUB2API`): a ready-to-launch global payment stack for AI startups with WeChat Pay and Alipay support, local-currency checkout, and USD settlement. Fees: WeChat 2%, Alipay 2.5%; withdrawal 0.1% (min $40, max $150), settled in **USDT or USD**. No qualification review required — sign up and use immediately, making it the lowest barrier to entry. Withdrawal threshold is relatively high, recommended for users **who do not use domestic Chinese payment channels, cannot tolerate Stripe's 6%+ fees, have high transaction volume, and have USD or USDT channels to receive withdrawn funds**. Kyren Topup charges a $200 account opening fee; signing up via this link (which contains Sub2Api author [@Wei-Shaw](https://github.com/Wei-Shaw)'s referral code) **waives the opening fee**. Feel free to remove it if you prefer. > > Please evaluate the security, reliability, and compliance of any third-party payment provider on your own — this project does not endorse or guarantee any of them. diff --git a/docs/PAYMENT_CN.md b/docs/PAYMENT_CN.md index 11325a80..aca3c866 100644 --- a/docs/PAYMENT_CN.md +++ b/docs/PAYMENT_CN.md @@ -31,7 +31,7 @@ Sub2API 内置支付系统,支持用户自助充值,无需部署独立的支 > **易支付服务商推荐**:以下两家均为兼容易支付协议的第三方聚合支付,按资金通道与结算方式选择: > > - **国内渠道 / 人民币结算** — [ZPay](https://z-pay.cn/?uid=23808)(`https://z-pay.cn/?uid=23808`):支付宝 / 微信官方 API 直连,手续费 **1.6%**;资金直达商家账户,**T+1 自动到账**。支持**个人用户**(无营业执照)每日 1 万元以内交易;拥有营业执照则无限额。链接含 [Sub2ApiPay](https://github.com/touwaeriol/sub2apipay) 原作者 [@touwaeriol](https://github.com/touwaeriol) 的邀请码,介意可去掉。 -> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。链接含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码,介意可去掉。 +> - **国际渠道 / USDT 或美元结算** — [启润支付](https://kyren.top/?code=SUB2API)(`https://kyren.top/?code=SUB2API`):为 AI 项目提供低门槛国际收款通道,支持国际版微信支付与支付宝,本地货币支付、美元结算。手续费:微信 2%、支付宝 2.5%;提现 0.1%(最低 40 美元、最高 150 美元),以 **USDT 或美元**到账。无资质审核、注册即用,使用门槛最低;提现门槛略高,适合**不使用国内支付渠道、无法接受 Stripe 高达 6%+ 手续费、流水较大,且拥有美元或 USDT 渠道可接收提现资金**的用户。启润支付开户费 200 美元,通过本链接注册(含 Sub2Api 作者 [@Wei-Shaw](https://github.com/Wei-Shaw) 邀请码)可**免开户费**,介意可去掉。 > > 支付渠道的安全性、稳定性及合规性请自行鉴别,本项目不对任何第三方支付服务商做担保或背书。 -- GitLab From c3cb0280ef8dff949ad7ad8f84793872b9304686 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 19 Apr 2026 01:40:25 +0800 Subject: [PATCH 022/265] fix(payment): alipay redirect-only flow, H5 detection and popup sizing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The native Alipay provider previously tried to embed the payment page URL into a QR code on the client — the URL is not a scannable payload so the QR never worked. Merchants also hit a H5 detection mismatch whenever the backend UA sniffer missed iPadOS 13+ or embedded browsers, and the popup window was too small for Alipay's standard checkout layout (QR + account-login panel on the right), forcing the user to scroll horizontally and vertically. Changes: Backend - alipay.go: drop QR-on-URL path. Use redirect-only flow — alipay.trade.page.pay for PC (returns a gateway URL the browser opens in a new window) and alipay.trade.wap.pay for H5 (returns a URL the browser jumps to). Both flows produce pages on openapi.alipaydev.com / excashier.alipay.com; the client never renders a QR itself. - payment_handler.go: add optional is_mobile bool to CreateOrderRequest so the frontend can declare the device explicitly. Server still falls back to UA sniffing when absent. Frontend - types/payment.ts, PaymentView.vue: declare is_mobile in CreateOrderRequest and pass the computed isMobileDevice() value. - providerConfig.ts: replace the two fixed POPUP_WINDOW_FEATURES constants with getPaymentPopupFeatures(), which prefers 1250×900 (Alipay's checkout footprint), clamps to window.screen.avail* and centers the popup so it never overflows on smaller laptops. - PaymentQRDialog.vue, PaymentStatusPanel.vue, StripePaymentInline.vue, PaymentView.vue: use the new helper at all popup call sites. --- backend/internal/handler/payment_handler.go | 10 +++- backend/internal/payment/provider/alipay.go | 50 ++++++++++--------- .../components/payment/PaymentQRDialog.vue | 4 +- .../components/payment/PaymentStatusPanel.vue | 4 +- .../payment/StripePaymentInline.vue | 4 +- .../src/components/payment/providerConfig.ts | 23 +++++++-- frontend/src/types/payment.ts | 1 + frontend/src/views/user/PaymentView.vue | 5 +- 8 files changed, 64 insertions(+), 37 deletions(-) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 1ddb8ae2..854dca54 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -206,6 +206,10 @@ type CreateOrderRequest struct { PaymentType string `json:"payment_type" binding:"required"` OrderType string `json:"order_type"` PlanID int64 `json:"plan_id"` + // IsMobile lets the frontend declare its mobile status directly. When + // nil we fall back to User-Agent heuristics (which miss iPadOS / some + // embedded browsers that strip the "Mobile" keyword). + IsMobile *bool `json:"is_mobile,omitempty"` } // CreateOrder creates a new payment order. @@ -222,12 +226,16 @@ func (h *PaymentHandler) CreateOrder(c *gin.Context) { return } + mobile := isMobile(c) + if req.IsMobile != nil { + mobile = *req.IsMobile + } result, err := h.paymentService.CreateOrder(c.Request.Context(), service.CreateOrderRequest{ UserID: subject.UserID, Amount: req.Amount, PaymentType: req.PaymentType, ClientIP: c.ClientIP(), - IsMobile: isMobile(c), + IsMobile: mobile, SrcHost: c.Request.Host, SrcURL: c.Request.Referer(), OrderType: req.OrderType, diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index af8a90c6..fe8ea89c 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -15,8 +15,8 @@ import ( // Alipay product codes. const ( - alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" ) // Alipay response constants. @@ -79,7 +79,12 @@ func (a *Alipay) SupportedTypes() []payment.PaymentType { return []payment.PaymentType{payment.TypeAlipay} } -// CreatePayment creates an Alipay payment page URL. +// CreatePayment creates an Alipay payment using redirect-only flow: +// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to. +// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a +// new window; Alipay's own page then shows login/QR. We intentionally do +// NOT encode the URL into a QR on the client (it isn't a scannable payload +// and would produce an invalid scan result). func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { @@ -96,31 +101,31 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque } if req.IsMobile { - return a.createTrade(client, req, notifyURL, returnURL, true) + return a.createWapTrade(client, req, notifyURL, returnURL) } - return a.createTrade(client, req, notifyURL, returnURL, false) + return a.createPagePayTrade(client, req, notifyURL, returnURL) } -func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string, isMobile bool) (*payment.CreatePaymentResponse, error) { - if isMobile { - param := alipay.TradeWapPay{} - param.OutTradeNo = req.OrderID - param.TotalAmount = req.Amount - param.Subject = req.Subject - param.ProductCode = alipayProductCodeWapPay - param.NotifyURL = notifyURL - param.ReturnURL = returnURL - - payURL, err := client.TradeWapPay(param) - if err != nil { - return nil, fmt.Errorf("alipay TradeWapPay: %w", err) - } - return &payment.CreatePaymentResponse{ - TradeNo: req.OrderID, - PayURL: payURL.String(), - }, nil +func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradeWapPay{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodeWapPay + param.NotifyURL = notifyURL + param.ReturnURL = returnURL + + payURL, err := client.TradeWapPay(param) + if err != nil { + return nil, fmt.Errorf("alipay TradeWapPay: %w", err) } + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + PayURL: payURL.String(), + }, nil +} +func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { param := alipay.TradePagePay{} param.OutTradeNo = req.OrderID param.TotalAmount = req.Amount @@ -136,7 +141,6 @@ func (a *Alipay) createTrade(client *alipay.Client, req payment.CreatePaymentReq return &payment.CreatePaymentResponse{ TradeNo: req.OrderID, PayURL: payURL.String(), - QRCode: payURL.String(), }, nil } diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue index b9026e78..db90c3b6 100644 --- a/frontend/src/components/payment/PaymentQRDialog.vue +++ b/frontend/src/components/payment/PaymentQRDialog.vue @@ -79,7 +79,7 @@ import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' import { extractApiErrorMessage } from '@/utils/apiError' -import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import QRCode from 'qrcode' import alipayIcon from '@/assets/icons/alipay.svg' @@ -147,7 +147,7 @@ function getLogoForType(): string | null { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures()) } } diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 974dee66..17541e59 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -125,7 +125,7 @@ import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' import { extractApiErrorMessage } from '@/utils/apiError' -import { POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import Icon from '@/components/icons/Icon.vue' import QRCode from 'qrcode' @@ -194,7 +194,7 @@ const countdownDisplay = computed(() => { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + window.open(props.payUrl, 'paymentPopup', getPaymentPopupFeatures()) } } diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue index b8fd55ef..3ddff8c8 100644 --- a/frontend/src/components/payment/StripePaymentInline.vue +++ b/frontend/src/components/payment/StripePaymentInline.vue @@ -70,7 +70,7 @@ import { useRouter } from 'vue-router' import { extractApiErrorMessage } from '@/utils/apiError' import { paymentAPI } from '@/api/payment' import { useAppStore } from '@/stores' -import { STRIPE_POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { Stripe, StripeElements } from '@stripe/stripe-js' import Icon from '@/components/icons/Icon.vue' @@ -151,7 +151,7 @@ async function handlePay() { amount: String(props.payAmount), }, }).href - const popup = window.open(popupUrl, 'paymentPopup', STRIPE_POPUP_WINDOW_FEATURES) + const popup = window.open(popupUrl, 'paymentPopup', getPaymentPopupFeatures()) const onReady = (event: MessageEvent) => { if (event.source !== popup || event.data?.type !== 'STRIPE_POPUP_READY') return diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts index a83787fd..bf2d4177 100644 --- a/frontend/src/components/payment/providerConfig.ts +++ b/frontend/src/components/payment/providerConfig.ts @@ -43,11 +43,24 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct', export const PAYMENT_MODE_QRCODE = 'qrcode' export const PAYMENT_MODE_POPUP = 'popup' -/** Window features for payment popup windows */ -export const POPUP_WINDOW_FEATURES = 'width=1000,height=750,left=100,top=80,scrollbars=yes,resizable=yes' - -/** Wider popup for Stripe redirect methods (Alipay checkout page needs ~1200px) */ -export const STRIPE_POPUP_WINDOW_FEATURES = 'width=1250,height=780,left=80,top=60,scrollbars=yes,resizable=yes' +/** Preferred popup size for payment gateways. Alipay's standard checkout + * (QR + account login panel) needs ~1200×900 to render without any scrolling. */ +const PAYMENT_POPUP_PREFERRED_WIDTH = 1250 +const PAYMENT_POPUP_PREFERRED_HEIGHT = 900 + +/** Build a window.open features string sized to fit within the current screen + * while preferring the above dimensions. Centers the popup on the available + * work area so nothing is clipped on smaller laptop displays. */ +export function getPaymentPopupFeatures(): string { + const screen = typeof window !== 'undefined' ? window.screen : null + const availW = screen?.availWidth ?? PAYMENT_POPUP_PREFERRED_WIDTH + const availH = screen?.availHeight ?? PAYMENT_POPUP_PREFERRED_HEIGHT + const width = Math.min(PAYMENT_POPUP_PREFERRED_WIDTH, availW - 40) + const height = Math.min(PAYMENT_POPUP_PREFERRED_HEIGHT, availH - 40) + const left = Math.max(0, Math.floor((availW - width) / 2)) + const top = Math.max(0, Math.floor((availH - height) / 2)) + return `width=${width},height=${height},left=${left},top=${top},scrollbars=yes,resizable=yes` +} /** Webhook paths for each provider (relative to origin). */ export const WEBHOOK_PATHS: Record = { diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts index 7ecbb9a9..6f2eec51 100644 --- a/frontend/src/types/payment.ts +++ b/frontend/src/types/payment.ts @@ -154,6 +154,7 @@ export interface CreateOrderRequest { payment_type: string order_type: string plan_id?: number + is_mobile?: boolean } export interface CreateOrderResult { diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da..3f1401b3 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -277,7 +277,7 @@ import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/ import AppLayout from '@/components/layout/AppLayout.vue' import AmountInput from '@/components/payment/AmountInput.vue' import PaymentMethodSelector from '@/components/payment/PaymentMethodSelector.vue' -import { METHOD_ORDER, POPUP_WINDOW_FEATURES } from '@/components/payment/providerConfig' +import { METHOD_ORDER, getPaymentPopupFeatures } from '@/components/payment/providerConfig' import { platformAccentBarClass, platformBadgeLightClass, platformBadgeClass, platformTextClass, platformLabel } from '@/utils/platformColors' import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue' import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue' @@ -551,9 +551,10 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n payment_type: selectedMethod.value, order_type: orderType, plan_id: planId, + is_mobile: isMobileDevice(), }) const openWindow = (url: string) => { - const win = window.open(url, 'paymentPopup', POPUP_WINDOW_FEATURES) + const win = window.open(url, 'paymentPopup', getPaymentPopupFeatures()) if (!win || win.closed) { window.location.href = url } -- GitLab From 235f710853b3e44b5b2fa236cfefd0b9d0c2a044 Mon Sep 17 00:00:00 2001 From: erio Date: Sun, 19 Apr 2026 01:46:50 +0800 Subject: [PATCH 023/265] feat(payment): redact provider secrets in admin config API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Admin GET /api/v1/admin/payment/providers previously returned every config value — including privateKey / apiV3Key / secretKey etc. — verbatim. Any future XSS on the admin UI would hand attackers the full set of production payment credentials, and the plaintext values sat unnecessarily in browser memory for every operator. Treat those fields as write-only from the admin surface: - decryptAndMaskConfig() strips sensitive keys from the GET response. The authoritative list is an explicit per-provider registry that mirrors the frontend's PROVIDER_CONFIG_FIELDS sensitive flag: alipay → privateKey, publicKey, alipayPublicKey wxpay → privateKey, apiV3Key, publicKey stripe → secretKey, webhookSecret (publishableKey stays plain) easypay → pkey Payment runtime still reads the full config via decryptConfig, so nothing at the gateway changes. - mergeConfig() treats an empty value for a sensitive key as "leave unchanged" — the admin UI omits unchanged secrets so operators can tweak non-sensitive settings without re-entering credentials. - Admin dialog (PaymentProviderDialog.vue): * secret inputs get autocomplete="new-password", data-1p-ignore, data-lpignore and data-bwignore so password managers do not offer to save provider credentials * in edit mode the required-field check skips sensitive fields (empty is the "keep existing" signal) and the placeholder shows "leave empty to keep" instead of the default example value * create mode still requires every non-optional field, including secrets, since there is nothing to preserve - Unit test renamed to TestIsSensitiveProviderConfigField, covers the per-provider registry and specifically asserts that Stripe's publishableKey is NOT treated as a secret. --- .../service/payment_config_providers.go | 78 +++++++++++++++---- .../service/payment_config_providers_test.go | 63 ++++++++------- .../payment/PaymentProviderDialog.vue | 23 ++++-- 3 files changed, 119 insertions(+), 45 deletions(-) diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 8e470525..3d1e4dc4 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -52,7 +52,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte AllowUserRefund: inst.AllowUserRefund, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } - resp.Config, err = s.decryptAndMaskConfig(inst.Config) + resp.Config, err = s.decryptAndMaskConfig(inst.ProviderKey, inst.Config) if err != nil { return nil, fmt.Errorf("decrypt config for instance %d: %w", inst.ID, err) } @@ -61,8 +61,26 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte return result, nil } -func (s *PaymentConfigService) decryptAndMaskConfig(encrypted string) (map[string]string, error) { - return s.decryptConfig(encrypted) +// decryptAndMaskConfig returns the stored config with sensitive fields omitted. +// Admin UIs display masked placeholders for these; the raw values never leave +// the server. Callers that need the full config (e.g. payment runtime) must +// use decryptConfig directly. +func (s *PaymentConfigService) decryptAndMaskConfig(providerKey, encrypted string) (map[string]string, error) { + cfg, err := s.decryptConfig(encrypted) + if err != nil { + return nil, err + } + if cfg == nil { + return nil, nil + } + masked := make(map[string]string, len(cfg)) + for k, v := range cfg { + if isSensitiveProviderConfigField(providerKey, k) { + continue + } + masked[k] = v + } + return masked, nil } // pendingOrderStatuses are order statuses considered "in progress". @@ -72,16 +90,27 @@ var pendingOrderStatuses = []string{ payment.OrderStatusRecharging, } -var sensitiveConfigPatterns = []string{"key", "pkey", "secret", "private", "password"} +// providerSensitiveConfigFields is the authoritative list of config keys that +// are treated as secrets per provider. Must stay in sync with the frontend +// definition at frontend/src/components/payment/providerConfig.ts +// (PROVIDER_CONFIG_FIELDS, fields with sensitive: true). +// +// Key matching is case-insensitive. Non-listed keys (e.g. appId, notifyUrl, +// stripe publishableKey) are returned in plaintext by the admin GET API. +var providerSensitiveConfigFields = map[string]map[string]struct{}{ + payment.TypeEasyPay: {"pkey": {}}, + payment.TypeAlipay: {"privatekey": {}, "publickey": {}, "alipaypublickey": {}}, + payment.TypeWxpay: {"privatekey": {}, "apiv3key": {}, "publickey": {}}, + payment.TypeStripe: {"secretkey": {}, "webhooksecret": {}}, +} -func isSensitiveConfigField(fieldName string) bool { - lower := strings.ToLower(fieldName) - for _, p := range sensitiveConfigPatterns { - if strings.Contains(lower, p) { - return true - } +func isSensitiveProviderConfigField(providerKey, fieldName string) bool { + fields, ok := providerSensitiveConfigFields[providerKey] + if !ok { + return false } - return false + _, found := fields[strings.ToLower(fieldName)] + return found } func (s *PaymentConfigService) countPendingOrders(ctx context.Context, providerInstanceID int64) (int, error) { @@ -137,10 +166,26 @@ func validateProviderRequest(providerKey, name, supportedTypes string) error { // NOTE: This function exceeds 30 lines due to per-field nil-check patch update // boilerplate and pending-order safety checks. func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id int64, req UpdateProviderInstanceRequest) (*dbent.PaymentProviderInstance, error) { + var cachedInst *dbent.PaymentProviderInstance + loadInst := func() (*dbent.PaymentProviderInstance, error) { + if cachedInst != nil { + return cachedInst, nil + } + inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) + if err != nil { + return nil, fmt.Errorf("load provider instance: %w", err) + } + cachedInst = inst + return inst, nil + } if req.Config != nil { + inst, err := loadInst() + if err != nil { + return nil, err + } hasSensitive := false - for k := range req.Config { - if isSensitiveConfigField(k) && req.Config[k] != "" { + for k, v := range req.Config { + if v != "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { hasSensitive = true break } @@ -283,9 +328,14 @@ func (s *PaymentConfigService) mergeConfig(ctx context.Context, id int64, newCon return nil, fmt.Errorf("decrypt existing config for instance %d: %w", id, err) } if existing == nil { - return newConfig, nil + existing = map[string]string{} } for k, v := range newConfig { + // Preserve existing secrets when the client submits an empty value + // (admin UI omits the value to indicate "leave unchanged"). + if v == "" && isSensitiveProviderConfigField(inst.ProviderKey, k) { + continue + } existing[k] = v } return existing, nil diff --git a/backend/internal/service/payment_config_providers_test.go b/backend/internal/service/payment_config_providers_test.go index 2aaa874f..bc2a9b18 100644 --- a/backend/internal/service/payment_config_providers_test.go +++ b/backend/internal/service/payment_config_providers_test.go @@ -97,41 +97,52 @@ func TestValidateProviderRequest(t *testing.T) { } } -func TestIsSensitiveConfigField(t *testing.T) { +func TestIsSensitiveProviderConfigField(t *testing.T) { t.Parallel() tests := []struct { - field string - wantSen bool + providerKey string + field string + wantSen bool }{ - // Sensitive fields (contain key/secret/private/password/pkey patterns) - {"secretKey", true}, - {"apiSecret", true}, - {"pkey", true}, - {"privateKey", true}, - {"apiPassword", true}, - {"appKey", true}, - {"SECRET_TOKEN", true}, - {"PrivateData", true}, - {"PASSWORD", true}, - {"mySecretValue", true}, - - // Non-sensitive fields - {"appId", false}, - {"mchId", false}, - {"apiBase", false}, - {"endpoint", false}, - {"merchantNo", false}, - {"paymentMode", false}, - {"notifyUrl", false}, + // Stripe: publishableKey is public, only secretKey/webhookSecret are secrets + {"stripe", "secretKey", true}, + {"stripe", "webhookSecret", true}, + {"stripe", "SecretKey", true}, // case-insensitive + {"stripe", "publishableKey", false}, + {"stripe", "appId", false}, + + // Alipay + {"alipay", "privateKey", true}, + {"alipay", "publicKey", true}, + {"alipay", "alipayPublicKey", true}, + {"alipay", "appId", false}, + {"alipay", "notifyUrl", false}, + + // Wxpay + {"wxpay", "privateKey", true}, + {"wxpay", "apiV3Key", true}, + {"wxpay", "publicKey", true}, + {"wxpay", "publicKeyId", false}, + {"wxpay", "certSerial", false}, + {"wxpay", "mchId", false}, + + // EasyPay + {"easypay", "pkey", true}, + {"easypay", "pid", false}, + {"easypay", "apiBase", false}, + + // Unknown provider: never sensitive + {"unknown", "secretKey", false}, } for _, tc := range tests { - t.Run(tc.field, func(t *testing.T) { + tc := tc + t.Run(tc.providerKey+"/"+tc.field, func(t *testing.T) { t.Parallel() - got := isSensitiveConfigField(tc.field) - assert.Equal(t, tc.wantSen, got, "isSensitiveConfigField(%q)", tc.field) + got := isSensitiveProviderConfigField(tc.providerKey, tc.field) + assert.Equal(t, tc.wantSen, got, "isSensitiveProviderConfigField(%q, %q)", tc.providerKey, tc.field) }) } } diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue index 10c1bfea..624ddcdd 100644 --- a/frontend/src/components/payment/PaymentProviderDialog.vue +++ b/frontend/src/components/payment/PaymentProviderDialog.vue @@ -88,13 +88,24 @@ v-model="config[field.key]" rows="3" class="input font-mono text-xs" + autocomplete="new-password" + data-1p-ignore + data-lpignore="true" + data-bwignore="true" + spellcheck="false" + :placeholder="editing ? t('admin.accounts.leaveEmptyToKeep') : ''" />
- +
@@ -1391,7 +1391,7 @@ {{ t('admin.settings.oidc.validateIdToken') }}
- +
@@ -3024,7 +3024,7 @@ const form = reactive({ oidc_connect_redirect_url: '', oidc_connect_frontend_redirect_url: '/auth/oidc/callback', oidc_connect_token_auth_method: 'client_secret_post', - oidc_connect_use_pkce: false, + oidc_connect_use_pkce: true, oidc_connect_validate_id_token: true, oidc_connect_allowed_signing_algs: 'RS256,ES256,PS256', oidc_connect_clock_skew_seconds: 120, @@ -3613,8 +3613,8 @@ async function saveSettings() { oidc_connect_redirect_url: form.oidc_connect_redirect_url, oidc_connect_frontend_redirect_url: form.oidc_connect_frontend_redirect_url, oidc_connect_token_auth_method: form.oidc_connect_token_auth_method, - oidc_connect_use_pkce: form.oidc_connect_use_pkce, - oidc_connect_validate_id_token: form.oidc_connect_validate_id_token, + oidc_connect_use_pkce: true, + oidc_connect_validate_id_token: true, oidc_connect_allowed_signing_algs: form.oidc_connect_allowed_signing_algs, oidc_connect_clock_skew_seconds: form.oidc_connect_clock_skew_seconds, oidc_connect_require_email_verified: form.oidc_connect_require_email_verified, -- GitLab From fbd0a2e3c488720025a3408e9db234407d8aef9b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 16:27:23 +0800 Subject: [PATCH 033/265] feat: carry suggested third-party profile through pending oauth --- .../internal/handler/auth_linuxdo_oauth.go | 201 +++++++++---- .../handler/auth_linuxdo_oauth_test.go | 12 +- .../handler/auth_oauth_pending_flow.go | 263 ++++++++++++++++++ .../handler/auth_oauth_pending_flow_test.go | 40 +++ backend/internal/handler/auth_oidc_oauth.go | 47 +++- .../internal/handler/auth_oidc_oauth_test.go | 20 ++ frontend/src/api/auth.ts | 24 +- 7 files changed, 534 insertions(+), 73 deletions(-) create mode 100644 backend/internal/handler/auth_oauth_pending_flow.go create mode 100644 backend/internal/handler/auth_oauth_pending_flow_test.go diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 0c7c2da7..2f182642 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -87,20 +87,25 @@ func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + secureCookie := isRequestHTTPS(c) setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) - codeChallenge := "" - if cfg.UsePKCE { - verifier, err := oauth.GenerateCodeVerifier() - if err != nil { - response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) - return - } - codeChallenge = oauth.GenerateCodeChallenge(verifier) - setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return } + codeChallenge := oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) redirectURI := strings.TrimSpace(cfg.RedirectURL) if redirectURI == "" { @@ -161,14 +166,16 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if redirectTo == "" { redirectTo = linuxDoOAuthDefaultRedirectTo } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } - codeVerifier := "" - if cfg.UsePKCE { - codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) - if codeVerifier == "" { - redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") - return - } + codeVerifier, _ := readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return } redirectURI := strings.TrimSpace(cfg.RedirectURL) @@ -198,7 +205,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + email, username, subject, displayName, avatarURL, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) if err != nil { log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") @@ -215,16 +222,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { - pendingToken, tokenErr := h.authService.CreatePendingOAuthToken(email, username) - if tokenErr != nil { - redirectOAuthError(c, frontendCallback, "login_failed", "service_error", "") + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "error": "invitation_required", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") return } - fragment := url.Values{} - fragment.Set("error", "invitation_required") - fragment.Set("pending_oauth_token", pendingToken) - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + redirectToFrontendCallback(c, frontendCallback) return } // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 @@ -232,18 +255,39 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } - fragment := url.Values{} - fragment.Set("access_token", tokenPair.AccessToken) - fragment.Set("refresh_token", tokenPair.RefreshToken) - fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) - fragment.Set("token_type", "Bearer") - fragment.Set("redirect", redirectTo) - redirectWithFragment(c, frontendCallback, fragment) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "login", + Identity: service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + }, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) } type completeLinuxDoOAuthRequest struct { - PendingOAuthToken string `json:"pending_oauth_token" binding:"required"` - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -256,9 +300,38 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - email, username, err := h.authService.VerifyPendingOAuthToken(req.PendingOAuthToken) + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "INVALID_TOKEN", "message": "invalid or expired registration token"}) + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) return } @@ -267,6 +340,14 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) c.JSON(http.StatusOK, gin.H{ "access_token": tokenPair.AccessToken, @@ -303,9 +384,7 @@ func linuxDoExchangeCode( form.Set("client_id", cfg.ClientID) form.Set("code", code) form.Set("redirect_uri", redirectURI) - if cfg.UsePKCE { - form.Set("code_verifier", codeVerifier) - } + form.Set("code_verifier", codeVerifier) r := client.R(). SetContext(ctx). @@ -353,11 +432,11 @@ func linuxDoFetchUserInfo( ctx context.Context, cfg config.LinuxDoConnectConfig, token *linuxDoTokenResponse, -) (email string, username string, subject string, err error) { +) (email string, username string, subject string, displayName string, avatarURL string, err error) { client := req.C().SetTimeout(30 * time.Second) authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) if err != nil { - return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + return "", "", "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) } resp, err := client.R(). @@ -366,16 +445,16 @@ func linuxDoFetchUserInfo( SetHeader("Authorization", authorization). Get(cfg.UserInfoURL) if err != nil { - return "", "", "", fmt.Errorf("request userinfo: %w", err) + return "", "", "", "", "", fmt.Errorf("request userinfo: %w", err) } if !resp.IsSuccessState() { - return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + return "", "", "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) } return linuxDoParseUserInfo(resp.String(), cfg) } -func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, displayName string, avatarURL string, err error) { email = firstNonEmpty( getGJSON(body, cfg.UserInfoEmailPath), getGJSON(body, "email"), @@ -400,12 +479,29 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s getGJSON(body, "user.id"), ) + displayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "user.name"), + getGJSON(body, "user.username"), + username, + ) + avatarURL = firstNonEmpty( + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "picture"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) + subject = strings.TrimSpace(subject) if subject == "" { - return "", "", "", errors.New("userinfo missing id field") + return "", "", "", "", "", errors.New("userinfo missing id field") } if !isSafeLinuxDoSubject(subject) { - return "", "", "", errors.New("userinfo returned invalid id field") + return "", "", "", "", "", errors.New("userinfo returned invalid id field") } email = strings.TrimSpace(email) @@ -418,8 +514,13 @@ func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email s if username == "" { username = "linuxdo_" + subject } + displayName = strings.TrimSpace(displayName) + if displayName == "" { + displayName = username + } + avatarURL = strings.TrimSpace(avatarURL) - return email, username, subject, nil + return email, username, subject, displayName, avatarURL, nil } func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { @@ -436,10 +537,8 @@ func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, cod q.Set("scope", cfg.Scopes) } q.Set("state", state) - if cfg.UsePKCE { - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - } + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") u.RawQuery = q.Encode() return u.String(), nil diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index ff169c52..90bc10d1 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -41,11 +41,13 @@ func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":123,"username":"alice","name":"Alice","avatar_url":"https://cdn.example/avatar.png"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "alice", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "Alice", displayName) + require.Equal(t, "https://cdn.example/avatar.png", avatarURL) } func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { @@ -53,11 +55,13 @@ func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + email, username, subject, displayName, avatarURL, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) require.NoError(t, err) require.Equal(t, "123", subject) require.Equal(t, "linuxdo_123", username) require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) + require.Equal(t, "linuxdo_123", displayName) + require.Equal(t, "", avatarURL) } func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { @@ -65,11 +69,11 @@ func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { UserInfoURL: "https://connect.linux.do/api/user", } - _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + _, _, _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) require.Error(t, err) tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) - _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + _, _, _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) require.Error(t, err) } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go new file mode 100644 index 00000000..a758c0b9 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -0,0 +1,263 @@ +package handler + +import ( + "net/http" + "net/url" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + oauthPendingBrowserCookiePath = "/api/v1/auth/oauth" + oauthPendingBrowserCookieName = "oauth_pending_browser_session" + oauthPendingSessionCookiePath = "/api/v1/auth/oauth/pending" + oauthPendingSessionCookieName = "oauth_pending_session" + oauthPendingCookieMaxAgeSec = 10 * 60 + + oauthCompletionResponseKey = "completion_response" +) + +type oauthPendingSessionPayload struct { + Intent string + Identity service.PendingAuthIdentityKey + ResolvedEmail string + RedirectTo string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + CompletionResponse map[string]any +} + +func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { + if h == nil || h.authService == nil || h.authService.EntClient() == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + return service.NewAuthPendingIdentityService(h.authService.EntClient()), nil +} + +func generateOAuthPendingBrowserSession() (string, error) { + return oauth.GenerateState() +} + +func setOAuthPendingBrowserCookie(c *gin.Context, sessionKey string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: encodeCookieValue(sessionKey), + Path: oauthPendingBrowserCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingBrowserCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingBrowserCookieName, + Value: "", + Path: oauthPendingBrowserCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingBrowserCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingBrowserCookieName) +} + +func setOAuthPendingSessionCookie(c *gin.Context, sessionToken string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: encodeCookieValue(sessionToken), + Path: oauthPendingSessionCookiePath, + MaxAge: oauthPendingCookieMaxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearOAuthPendingSessionCookie(c *gin.Context, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: oauthPendingSessionCookieName, + Value: "", + Path: oauthPendingSessionCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func readOAuthPendingSessionCookie(c *gin.Context) (string, error) { + return readCookieDecoded(c, oauthPendingSessionCookieName) +} + +func redirectToFrontendCallback(c *gin.Context, frontendCallback string) { + u, err := url.Parse(frontendCallback) + if err != nil { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = "" + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPendingSessionPayload) error { + svc, err := h.pendingIdentityService() + if err != nil { + return err + } + + session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ + Intent: strings.TrimSpace(payload.Intent), + Identity: payload.Identity, + ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), + RedirectTo: strings.TrimSpace(payload.RedirectTo), + BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), + UpstreamIdentityClaims: payload.UpstreamIdentityClaims, + LocalFlowState: map[string]any{ + oauthCompletionResponseKey: payload.CompletionResponse, + }, + }) + if err != nil { + return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) + } + + setOAuthPendingSessionCookie(c, session.SessionToken, isRequestHTTPS(c)) + return nil +} + +func readCompletionResponse(session map[string]any) (map[string]any, bool) { + if len(session) == 0 { + return nil, false + } + value, ok := session[oauthCompletionResponseKey] + if !ok { + return nil, false + } + result, ok := value.(map[string]any) + if !ok { + return nil, false + } + return result, true +} + +func pendingSessionStringValue(values map[string]any, key string) string { + if len(values) == 0 { + return "" + } + raw, ok := values[key] + if !ok { + return "" + } + value, ok := raw.(string) + if !ok { + return "" + } + return strings.TrimSpace(value) +} + +func pendingSessionWantsInvitation(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") +} + +func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { + if len(payload) == 0 || len(upstream) == 0 { + return + } + + displayName := pendingSessionStringValue(upstream, "suggested_display_name") + avatarURL := pendingSessionStringValue(upstream, "suggested_avatar_url") + + if displayName != "" { + if _, exists := payload["suggested_display_name"]; !exists { + payload["suggested_display_name"] = displayName + } + } + if avatarURL != "" { + if _, exists := payload["suggested_avatar_url"]; !exists { + payload["suggested_avatar_url"] = avatarURL + } + } + if displayName != "" || avatarURL != "" { + payload["adoption_required"] = true + } +} + +// ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. +// POST /api/v1/auth/oauth/pending/exchange +func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + payload, ok := readCompletionResponse(session.LocalFlowState) + if !ok { + clearCookies() + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_COMPLETION_INVALID", "pending auth completion payload is invalid")) + return + } + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := payload["redirect"]; !exists { + payload["redirect"] = session.RedirectTo + } + } + applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) + + if pendingSessionWantsInvitation(payload) { + response.Success(c, payload) + return + } + + if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + response.Success(c, payload) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go new file mode 100644 index 00000000..5517bae2 --- /dev/null +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -0,0 +1,40 @@ +package handler + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { + payload := map[string]any{ + "access_token": "token", + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Alice", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} + +func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t *testing.T) { + payload := map[string]any{ + "suggested_display_name": "Existing", + "adoption_required": false, + } + upstream := map[string]any{ + "suggested_display_name": "Alice", + "suggested_avatar_url": "https://cdn.example/avatar.png", + } + + applySuggestedProfileToCompletionResponse(payload, upstream) + + require.Equal(t, "Existing", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) + require.Equal(t, true, payload["adoption_required"]) +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 37ef6833..e3694c8f 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -87,6 +87,8 @@ type oidcUserInfoClaims struct { Username string Subject string EmailVerified *bool + DisplayName string + AvatarURL string } type oidcJWKSet struct { @@ -338,12 +340,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "error": "invitation_required", @@ -371,12 +375,14 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, }, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, @@ -643,9 +649,26 @@ func oidcParseUserInfo(body string, cfg config.OIDCConnectConfig) *oidcUserInfoC if verified, ok := getGJSONBool(body, "email_verified"); ok { claims.EmailVerified = &verified } + claims.DisplayName = firstNonEmpty( + getGJSON(body, "name"), + getGJSON(body, "nickname"), + getGJSON(body, "display_name"), + getGJSON(body, "preferred_username"), + getGJSON(body, "username"), + ) + claims.AvatarURL = firstNonEmpty( + getGJSON(body, "picture"), + getGJSON(body, "avatar_url"), + getGJSON(body, "avatar"), + getGJSON(body, "profile_image_url"), + getGJSON(body, "user.avatar"), + getGJSON(body, "user.avatar_url"), + ) claims.Email = strings.TrimSpace(claims.Email) claims.Username = strings.TrimSpace(claims.Username) claims.Subject = strings.TrimSpace(claims.Subject) + claims.DisplayName = strings.TrimSpace(claims.DisplayName) + claims.AvatarURL = strings.TrimSpace(claims.AvatarURL) return claims } diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index a4cf776a..c389db51 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -91,6 +91,26 @@ func TestOIDCParseAndValidateIDToken(t *testing.T) { require.Error(t, err) } +func TestOIDCParseUserInfoIncludesSuggestedProfile(t *testing.T) { + cfg := config.OIDCConnectConfig{} + + claims := oidcParseUserInfo(`{ + "sub":"subject-1", + "preferred_username":"alice", + "name":"Alice Example", + "picture":"https://cdn.example/avatar.png", + "email":"alice@example.com", + "email_verified":true + }`, cfg) + + require.Equal(t, "subject-1", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Example", claims.DisplayName) + require.Equal(t, "https://cdn.example/avatar.png", claims.AvatarURL) + require.NotNil(t, claims.EmailVerified) + require.True(t, *claims.EmailVerified) +} + func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { n := base64.RawURLEncoding.EncodeToString(pub.N.Bytes()) e := base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 837c4f4c..d7abcd6a 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -186,6 +186,18 @@ export interface RefreshTokenResponse { token_type: string } +export interface PendingOAuthExchangeResponse { + access_token?: string + refresh_token?: string + expires_in?: number + token_type?: string + redirect?: string + error?: string + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -337,12 +349,10 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ @@ -351,7 +361,6 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data @@ -359,12 +368,10 @@ export async function completeLinuxDoOAuthRegistration( /** * Complete OIDC OAuth registration by supplying an invitation code - * @param pendingOAuthToken - Short-lived JWT from the OAuth callback * @param invitationCode - Invitation code entered by the user * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - pendingOAuthToken: string, invitationCode: string ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ @@ -373,12 +380,16 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - pending_oauth_token: pendingOAuthToken, invitation_code: invitationCode }) return data } +export async function exchangePendingOAuthCompletion(): Promise { + const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) + return data +} + export const authAPI = { login, login2FA, @@ -402,6 +413,7 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, + exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration } -- GitLab From e9de839d8791151314873bad11ac9583dabd2738 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 17:39:57 +0800 Subject: [PATCH 034/265] feat: rebuild auth identity foundation flow --- backend/ent/authidentity.go | 266 + backend/ent/authidentity/authidentity.go | 209 + backend/ent/authidentity/where.go | 600 + backend/ent/authidentity_create.go | 1036 + backend/ent/authidentity_delete.go | 88 + backend/ent/authidentity_query.go | 797 + backend/ent/authidentity_update.go | 923 + backend/ent/authidentitychannel.go | 228 + .../authidentitychannel.go | 153 + backend/ent/authidentitychannel/where.go | 559 + backend/ent/authidentitychannel_create.go | 932 + backend/ent/authidentitychannel_delete.go | 88 + backend/ent/authidentitychannel_query.go | 643 + backend/ent/authidentitychannel_update.go | 581 + backend/ent/client.go | 876 +- backend/ent/ent.go | 60 +- backend/ent/hook/hook.go | 48 + backend/ent/identityadoptiondecision.go | 223 + .../identityadoptiondecision.go | 159 + backend/ent/identityadoptiondecision/where.go | 342 + .../ent/identityadoptiondecision_create.go | 843 + .../ent/identityadoptiondecision_delete.go | 88 + backend/ent/identityadoptiondecision_query.go | 721 + .../ent/identityadoptiondecision_update.go | 532 + backend/ent/intercept/intercept.go | 120 + backend/ent/migrate/schema.go | 216 + backend/ent/mutation.go | 17621 ++++++++++------ backend/ent/pendingauthsession.go | 399 + .../pendingauthsession/pendingauthsession.go | 279 + backend/ent/pendingauthsession/where.go | 1262 ++ backend/ent/pendingauthsession_create.go | 1815 ++ backend/ent/pendingauthsession_delete.go | 88 + backend/ent/pendingauthsession_query.go | 717 + backend/ent/pendingauthsession_update.go | 1178 ++ backend/ent/predicate/predicate.go | 12 + backend/ent/runtime/runtime.go | 266 +- backend/ent/schema/auth_identity.go | 93 + backend/ent/schema/auth_identity_channel.go | 72 + .../ent/schema/auth_identity_schema_test.go | 124 + .../ent/schema/identity_adoption_decision.go | 70 + backend/ent/schema/pending_auth_session.go | 134 + backend/ent/schema/user.go | 13 + backend/ent/tx.go | 12 + backend/ent/user.go | 79 +- backend/ent/user/user.go | 88 + backend/ent/user/where.go | 226 + backend/ent/user_create.go | 290 + backend/ent/user_query.go | 155 +- backend/ent/user_update.go | 474 + backend/internal/config/config.go | 15 +- .../internal/handler/admin/setting_handler.go | 189 +- ...tting_handler_auth_source_defaults_test.go | 149 + .../internal/handler/auth_linuxdo_oauth.go | 21 +- .../handler/auth_linuxdo_oauth_test.go | 87 + .../handler/auth_oauth_pending_flow.go | 325 + .../handler/auth_oauth_pending_flow_test.go | 457 + backend/internal/handler/auth_oidc_oauth.go | 21 +- .../internal/handler/auth_oidc_oauth_test.go | 84 + backend/internal/handler/auth_wechat_oauth.go | 618 + .../handler/auth_wechat_oauth_test.go | 411 + backend/internal/handler/dto/settings.go | 1 + .../handler/payment_webhook_handler.go | 2 +- .../handler/payment_webhook_handler_test.go | 34 + backend/internal/handler/setting_handler.go | 1 + backend/internal/handler/user_handler.go | 30 +- backend/internal/handler/user_handler_test.go | 136 + backend/internal/repository/api_key_repo.go | 6 + .../auth_identity_migration_report.go | 148 + .../repository/user_profile_identity_repo.go | 544 + ...ser_profile_identity_repo_contract_test.go | 428 + backend/internal/repository/user_repo.go | 44 + .../user_repo_sort_integration_test.go | 83 + backend/internal/server/api_contract_test.go | 4 +- backend/internal/server/routes/auth.go | 14 + .../service/admin_service_apikey_test.go | 9 + .../service/admin_service_delete_test.go | 12 + .../service/auth_pending_identity_service.go | 326 + .../auth_pending_identity_service_test.go | 224 + backend/internal/service/auth_service.go | 164 +- .../auth_service_identity_sync_test.go | 153 + .../auth_service_pending_oauth_test.go | 146 - backend/internal/service/domain_constants.go | 26 + .../service/openai_account_scheduler.go | 77 +- .../service/openai_account_scheduler_test.go | 430 +- ...enai_account_scheduler_ws_snapshot_test.go | 7 +- .../service/payment_config_service.go | 96 +- .../service/payment_config_service_test.go | 186 + .../service/payment_resume_service.go | 248 + .../service/payment_resume_service_test.go | 240 + backend/internal/service/payment_service.go | 40 +- backend/internal/service/setting_service.go | 256 +- ...tting_service_auth_source_defaults_test.go | 136 + backend/internal/service/settings_view.go | 1 + backend/internal/service/user.go | 34 +- backend/internal/service/user_service.go | 153 +- backend/internal/service/user_service_test.go | 180 +- .../108_auth_identity_foundation_core.sql | 141 + .../109_auth_identity_compat_backfill.sql | 125 + ...nding_auth_and_provider_default_grants.sql | 60 + ...11_payment_routing_and_scheduler_flags.sql | 8 + .../api/__tests__/auth-oauth-adoption.spec.ts | 60 + .../settings.authSourceDefaults.spec.ts | 118 + frontend/src/api/admin/settings.ts | 127 + frontend/src/api/auth.ts | 41 +- .../components/auth/WechatOAuthSection.vue | 53 + .../auth/__tests__/WechatOAuthSection.spec.ts | 74 + .../components/payment/PaymentStatusPanel.vue | 29 +- .../payment/__tests__/paymentFlow.spec.ts | 163 + .../src/components/payment/paymentFlow.ts | 197 + .../src/router/__tests__/wechat-route.spec.ts | 55 + frontend/src/router/index.ts | 9 + frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 1 + frontend/src/views/admin/SettingsView.vue | 497 +- .../src/views/auth/LinuxDoCallbackView.vue | 259 +- frontend/src/views/auth/LoginView.vue | 10 +- frontend/src/views/auth/OidcCallbackView.vue | 263 +- frontend/src/views/auth/RegisterView.vue | 10 +- .../src/views/auth/WechatCallbackView.vue | 361 + .../__tests__/LinuxDoCallbackView.spec.ts | 180 + .../auth/__tests__/OidcCallbackView.spec.ts | 191 + .../auth/__tests__/WechatCallbackView.spec.ts | 241 + frontend/src/views/user/PaymentView.vue | 229 +- 123 files changed, 40062 insertions(+), 7235 deletions(-) create mode 100644 backend/ent/authidentity.go create mode 100644 backend/ent/authidentity/authidentity.go create mode 100644 backend/ent/authidentity/where.go create mode 100644 backend/ent/authidentity_create.go create mode 100644 backend/ent/authidentity_delete.go create mode 100644 backend/ent/authidentity_query.go create mode 100644 backend/ent/authidentity_update.go create mode 100644 backend/ent/authidentitychannel.go create mode 100644 backend/ent/authidentitychannel/authidentitychannel.go create mode 100644 backend/ent/authidentitychannel/where.go create mode 100644 backend/ent/authidentitychannel_create.go create mode 100644 backend/ent/authidentitychannel_delete.go create mode 100644 backend/ent/authidentitychannel_query.go create mode 100644 backend/ent/authidentitychannel_update.go create mode 100644 backend/ent/identityadoptiondecision.go create mode 100644 backend/ent/identityadoptiondecision/identityadoptiondecision.go create mode 100644 backend/ent/identityadoptiondecision/where.go create mode 100644 backend/ent/identityadoptiondecision_create.go create mode 100644 backend/ent/identityadoptiondecision_delete.go create mode 100644 backend/ent/identityadoptiondecision_query.go create mode 100644 backend/ent/identityadoptiondecision_update.go create mode 100644 backend/ent/pendingauthsession.go create mode 100644 backend/ent/pendingauthsession/pendingauthsession.go create mode 100644 backend/ent/pendingauthsession/where.go create mode 100644 backend/ent/pendingauthsession_create.go create mode 100644 backend/ent/pendingauthsession_delete.go create mode 100644 backend/ent/pendingauthsession_query.go create mode 100644 backend/ent/pendingauthsession_update.go create mode 100644 backend/ent/schema/auth_identity.go create mode 100644 backend/ent/schema/auth_identity_channel.go create mode 100644 backend/ent/schema/auth_identity_schema_test.go create mode 100644 backend/ent/schema/identity_adoption_decision.go create mode 100644 backend/ent/schema/pending_auth_session.go create mode 100644 backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go create mode 100644 backend/internal/handler/auth_wechat_oauth.go create mode 100644 backend/internal/handler/auth_wechat_oauth_test.go create mode 100644 backend/internal/handler/user_handler_test.go create mode 100644 backend/internal/repository/auth_identity_migration_report.go create mode 100644 backend/internal/repository/user_profile_identity_repo.go create mode 100644 backend/internal/repository/user_profile_identity_repo_contract_test.go create mode 100644 backend/internal/service/auth_pending_identity_service.go create mode 100644 backend/internal/service/auth_pending_identity_service_test.go create mode 100644 backend/internal/service/auth_service_identity_sync_test.go delete mode 100644 backend/internal/service/auth_service_pending_oauth_test.go create mode 100644 backend/internal/service/payment_resume_service.go create mode 100644 backend/internal/service/payment_resume_service_test.go create mode 100644 backend/internal/service/setting_service_auth_source_defaults_test.go create mode 100644 backend/migrations/108_auth_identity_foundation_core.sql create mode 100644 backend/migrations/109_auth_identity_compat_backfill.sql create mode 100644 backend/migrations/110_pending_auth_and_provider_default_grants.sql create mode 100644 backend/migrations/111_payment_routing_and_scheduler_flags.sql create mode 100644 frontend/src/api/__tests__/auth-oauth-adoption.spec.ts create mode 100644 frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts create mode 100644 frontend/src/components/auth/WechatOAuthSection.vue create mode 100644 frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts create mode 100644 frontend/src/components/payment/__tests__/paymentFlow.spec.ts create mode 100644 frontend/src/components/payment/paymentFlow.ts create mode 100644 frontend/src/router/__tests__/wechat-route.spec.ts create mode 100644 frontend/src/views/auth/WechatCallbackView.vue create mode 100644 frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts create mode 100644 frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts create mode 100644 frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts diff --git a/backend/ent/authidentity.go b/backend/ent/authidentity.go new file mode 100644 index 00000000..5ccfcf19 --- /dev/null +++ b/backend/ent/authidentity.go @@ -0,0 +1,266 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentity is the model entity for the AuthIdentity schema. +type AuthIdentity struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // UserID holds the value of the "user_id" field. + UserID int64 `json:"user_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // VerifiedAt holds the value of the "verified_at" field. + VerifiedAt *time.Time `json:"verified_at,omitempty"` + // Issuer holds the value of the "issuer" field. + Issuer *string `json:"issuer,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityQuery when eager-loading is set. + Edges AuthIdentityEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Channels holds the value of the channels edge. + Channels []*AuthIdentityChannel `json:"channels,omitempty"` + // AdoptionDecisions holds the value of the adoption_decisions edge. + AdoptionDecisions []*IdentityAdoptionDecision `json:"adoption_decisions,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [3]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// ChannelsOrErr returns the Channels value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) ChannelsOrErr() ([]*AuthIdentityChannel, error) { + if e.loadedTypes[1] { + return e.Channels, nil + } + return nil, &NotLoadedError{edge: "channels"} +} + +// AdoptionDecisionsOrErr returns the AdoptionDecisions value or an error if the edge +// was not loaded in eager-loading. +func (e AuthIdentityEdges) AdoptionDecisionsOrErr() ([]*IdentityAdoptionDecision, error) { + if e.loadedTypes[2] { + return e.AdoptionDecisions, nil + } + return nil, &NotLoadedError{edge: "adoption_decisions"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentity.FieldMetadata: + values[i] = new([]byte) + case authidentity.FieldID, authidentity.FieldUserID: + values[i] = new(sql.NullInt64) + case authidentity.FieldProviderType, authidentity.FieldProviderKey, authidentity.FieldProviderSubject, authidentity.FieldIssuer: + values[i] = new(sql.NullString) + case authidentity.FieldCreatedAt, authidentity.FieldUpdatedAt, authidentity.FieldVerifiedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentity fields. +func (_m *AuthIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentity.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentity.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentity.FieldUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.Int64 + } + case authidentity.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentity.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentity.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case authidentity.FieldVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field verified_at", values[i]) + } else if value.Valid { + _m.VerifiedAt = new(time.Time) + *_m.VerifiedAt = value.Time + } + case authidentity.FieldIssuer: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field issuer", values[i]) + } else if value.Valid { + _m.Issuer = new(string) + *_m.Issuer = value.String + } + case authidentity.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryUser() *UserQuery { + return NewAuthIdentityClient(_m.config).QueryUser(_m) +} + +// QueryChannels queries the "channels" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryChannels() *AuthIdentityChannelQuery { + return NewAuthIdentityClient(_m.config).QueryChannels(_m) +} + +// QueryAdoptionDecisions queries the "adoption_decisions" edge of the AuthIdentity entity. +func (_m *AuthIdentity) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + return NewAuthIdentityClient(_m.config).QueryAdoptionDecisions(_m) +} + +// Update returns a builder for updating this AuthIdentity. +// Note that you need to call AuthIdentity.Unwrap() before calling this method if this AuthIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentity) Update() *AuthIdentityUpdateOne { + return NewAuthIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentity) Unwrap() *AuthIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentity) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(fmt.Sprintf("%v", _m.UserID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.VerifiedAt; v != nil { + builder.WriteString("verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.Issuer; v != nil { + builder.WriteString("issuer=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentities is a parsable slice of AuthIdentity. +type AuthIdentities []*AuthIdentity diff --git a/backend/ent/authidentity/authidentity.go b/backend/ent/authidentity/authidentity.go new file mode 100644 index 00000000..c90be759 --- /dev/null +++ b/backend/ent/authidentity/authidentity.go @@ -0,0 +1,209 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentity type in the database. + Label = "auth_identity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldVerifiedAt holds the string denoting the verified_at field in the database. + FieldVerifiedAt = "verified_at" + // FieldIssuer holds the string denoting the issuer field in the database. + FieldIssuer = "issuer" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeChannels holds the string denoting the channels edge name in mutations. + EdgeChannels = "channels" + // EdgeAdoptionDecisions holds the string denoting the adoption_decisions edge name in mutations. + EdgeAdoptionDecisions = "adoption_decisions" + // Table holds the table name of the authidentity in the database. + Table = "auth_identities" + // UserTable is the table that holds the user relation/edge. + UserTable = "auth_identities" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // ChannelsTable is the table that holds the channels relation/edge. + ChannelsTable = "auth_identity_channels" + // ChannelsInverseTable is the table name for the AuthIdentityChannel entity. + // It exists in this package in order to avoid circular dependency with the "authidentitychannel" package. + ChannelsInverseTable = "auth_identity_channels" + // ChannelsColumn is the table column denoting the channels relation/edge. + ChannelsColumn = "identity_id" + // AdoptionDecisionsTable is the table that holds the adoption_decisions relation/edge. + AdoptionDecisionsTable = "identity_adoption_decisions" + // AdoptionDecisionsInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionsInverseTable = "identity_adoption_decisions" + // AdoptionDecisionsColumn is the table column denoting the adoption_decisions relation/edge. + AdoptionDecisionsColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentity fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldUserID, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldVerifiedAt, + FieldIssuer, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByVerifiedAt orders the results by the verified_at field. +func ByVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldVerifiedAt, opts...).ToFunc() +} + +// ByIssuer orders the results by the issuer field. +func ByIssuer(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIssuer, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByChannelsCount orders the results by channels count. +func ByChannelsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newChannelsStep(), opts...) + } +} + +// ByChannels orders the results by channels terms. +func ByChannels(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newChannelsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAdoptionDecisionsCount orders the results by adoption_decisions count. +func ByAdoptionDecisionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAdoptionDecisionsStep(), opts...) + } +} + +// ByAdoptionDecisions orders the results by adoption_decisions terms. +func ByAdoptionDecisions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newChannelsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ChannelsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) +} +func newAdoptionDecisionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) +} diff --git a/backend/ent/authidentity/where.go b/backend/ent/authidentity/where.go new file mode 100644 index 00000000..3dbf3178 --- /dev/null +++ b/backend/ent/authidentity/where.go @@ -0,0 +1,600 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// VerifiedAt applies equality check predicate on the "verified_at" field. It's identical to VerifiedAtEQ. +func VerifiedAt(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// Issuer applies equality check predicate on the "issuer" field. It's identical to IssuerEQ. +func Issuer(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...int64) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// VerifiedAtEQ applies the EQ predicate on the "verified_at" field. +func VerifiedAtEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtNEQ applies the NEQ predicate on the "verified_at" field. +func VerifiedAtNEQ(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldVerifiedAt, v)) +} + +// VerifiedAtIn applies the In predicate on the "verified_at" field. +func VerifiedAtIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtNotIn applies the NotIn predicate on the "verified_at" field. +func VerifiedAtNotIn(vs ...time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldVerifiedAt, vs...)) +} + +// VerifiedAtGT applies the GT predicate on the "verified_at" field. +func VerifiedAtGT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldVerifiedAt, v)) +} + +// VerifiedAtGTE applies the GTE predicate on the "verified_at" field. +func VerifiedAtGTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldVerifiedAt, v)) +} + +// VerifiedAtLT applies the LT predicate on the "verified_at" field. +func VerifiedAtLT(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldVerifiedAt, v)) +} + +// VerifiedAtLTE applies the LTE predicate on the "verified_at" field. +func VerifiedAtLTE(v time.Time) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldVerifiedAt, v)) +} + +// VerifiedAtIsNil applies the IsNil predicate on the "verified_at" field. +func VerifiedAtIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldVerifiedAt)) +} + +// VerifiedAtNotNil applies the NotNil predicate on the "verified_at" field. +func VerifiedAtNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldVerifiedAt)) +} + +// IssuerEQ applies the EQ predicate on the "issuer" field. +func IssuerEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEQ(FieldIssuer, v)) +} + +// IssuerNEQ applies the NEQ predicate on the "issuer" field. +func IssuerNEQ(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNEQ(FieldIssuer, v)) +} + +// IssuerIn applies the In predicate on the "issuer" field. +func IssuerIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIn(FieldIssuer, vs...)) +} + +// IssuerNotIn applies the NotIn predicate on the "issuer" field. +func IssuerNotIn(vs ...string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotIn(FieldIssuer, vs...)) +} + +// IssuerGT applies the GT predicate on the "issuer" field. +func IssuerGT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGT(FieldIssuer, v)) +} + +// IssuerGTE applies the GTE predicate on the "issuer" field. +func IssuerGTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldGTE(FieldIssuer, v)) +} + +// IssuerLT applies the LT predicate on the "issuer" field. +func IssuerLT(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLT(FieldIssuer, v)) +} + +// IssuerLTE applies the LTE predicate on the "issuer" field. +func IssuerLTE(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldLTE(FieldIssuer, v)) +} + +// IssuerContains applies the Contains predicate on the "issuer" field. +func IssuerContains(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContains(FieldIssuer, v)) +} + +// IssuerHasPrefix applies the HasPrefix predicate on the "issuer" field. +func IssuerHasPrefix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasPrefix(FieldIssuer, v)) +} + +// IssuerHasSuffix applies the HasSuffix predicate on the "issuer" field. +func IssuerHasSuffix(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldHasSuffix(FieldIssuer, v)) +} + +// IssuerIsNil applies the IsNil predicate on the "issuer" field. +func IssuerIsNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldIsNull(FieldIssuer)) +} + +// IssuerNotNil applies the NotNil predicate on the "issuer" field. +func IssuerNotNil() predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldNotNull(FieldIssuer)) +} + +// IssuerEqualFold applies the EqualFold predicate on the "issuer" field. +func IssuerEqualFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldEqualFold(FieldIssuer, v)) +} + +// IssuerContainsFold applies the ContainsFold predicate on the "issuer" field. +func IssuerContainsFold(v string) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.FieldContainsFold(FieldIssuer, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasChannels applies the HasEdge predicate on the "channels" edge. +func HasChannels() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ChannelsTable, ChannelsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasChannelsWith applies the HasEdge predicate on the "channels" edge with a given conditions (other predicates). +func HasChannelsWith(preds ...predicate.AuthIdentityChannel) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newChannelsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecisions applies the HasEdge predicate on the "adoption_decisions" edge. +func HasAdoptionDecisions() predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AdoptionDecisionsTable, AdoptionDecisionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionsWith applies the HasEdge predicate on the "adoption_decisions" edge with a given conditions (other predicates). +func HasAdoptionDecisionsWith(preds ...predicate.IdentityAdoptionDecision) predicate.AuthIdentity { + return predicate.AuthIdentity(func(s *sql.Selector) { + step := newAdoptionDecisionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentity) predicate.AuthIdentity { + return predicate.AuthIdentity(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentity_create.go b/backend/ent/authidentity_create.go new file mode 100644 index 00000000..e287705c --- /dev/null +++ b/backend/ent/authidentity_create.go @@ -0,0 +1,1036 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityCreate is the builder for creating a AuthIdentity entity. +type AuthIdentityCreate struct { + config + mutation *AuthIdentityMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityCreate) SetCreatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityCreate) SetUpdatedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetUserID sets the "user_id" field. +func (_c *AuthIdentityCreate) SetUserID(v int64) *AuthIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityCreate) SetProviderType(v string) *AuthIdentityCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityCreate) SetProviderKey(v string) *AuthIdentityCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *AuthIdentityCreate) SetProviderSubject(v string) *AuthIdentityCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetVerifiedAt sets the "verified_at" field. +func (_c *AuthIdentityCreate) SetVerifiedAt(v time.Time) *AuthIdentityCreate { + _c.mutation.SetVerifiedAt(v) + return _c +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityCreate { + if v != nil { + _c.SetVerifiedAt(*v) + } + return _c +} + +// SetIssuer sets the "issuer" field. +func (_c *AuthIdentityCreate) SetIssuer(v string) *AuthIdentityCreate { + _c.mutation.SetIssuer(v) + return _c +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_c *AuthIdentityCreate) SetNillableIssuer(v *string) *AuthIdentityCreate { + if v != nil { + _c.SetIssuer(*v) + } + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityCreate) SetMetadata(v map[string]interface{}) *AuthIdentityCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetUser sets the "user" edge to the User entity. +func (_c *AuthIdentityCreate) SetUser(v *User) *AuthIdentityCreate { + return _c.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_c *AuthIdentityCreate) AddChannelIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddChannelIDs(ids...) + return _c +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_c *AuthIdentityCreate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_c *AuthIdentityCreate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityCreate { + _c.mutation.AddAdoptionDecisionIDs(ids...) + return _c +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_c *AuthIdentityCreate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_c *AuthIdentityCreate) Mutation() *AuthIdentityMutation { + return _c.mutation +} + +// Save creates the AuthIdentity in the database. +func (_c *AuthIdentityCreate) Save(ctx context.Context) (*AuthIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityCreate) SaveX(ctx context.Context) *AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentity.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentity.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentity.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentity.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentity.updated_at"`)} + } + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "AuthIdentity.user_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentity.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentity.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "AuthIdentity.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentity.metadata"`)} + } + if len(_c.mutation.UserIDs()) == 0 { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "AuthIdentity.user"`)} + } + return nil +} + +func (_c *AuthIdentityCreate) sqlSave(ctx context.Context) (*AuthIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityCreate) createSpec() (*AuthIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + _node.VerifiedAt = &value + } + if value, ok := _c.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + _node.Issuer = &value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertOne { + _c.conflict = opts + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreate) OnConflictColumns(columns ...string) *AuthIdentityUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityUpsertOne is the builder for "upsert"-ing + // one AuthIdentity node. + AuthIdentityUpsertOne struct { + create *AuthIdentityCreate + } + + // AuthIdentityUpsert is the "OnConflict" setter. + AuthIdentityUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsert) SetUpdatedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUpdatedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUpdatedAt) + return u +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsert) SetUserID(v int64) *AuthIdentityUpsert { + u.Set(authidentity.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateUserID() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldUserID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsert) SetProviderType(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderType() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsert) SetProviderKey(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderKey() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsert) SetProviderSubject(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateProviderSubject() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldProviderSubject) + return u +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsert) SetVerifiedAt(v time.Time) *AuthIdentityUpsert { + u.Set(authidentity.FieldVerifiedAt, v) + return u +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateVerifiedAt() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldVerifiedAt) + return u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsert) ClearVerifiedAt() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldVerifiedAt) + return u +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsert) SetIssuer(v string) *AuthIdentityUpsert { + u.Set(authidentity.FieldIssuer, v) + return u +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateIssuer() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldIssuer) + return u +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsert) ClearIssuer() *AuthIdentityUpsert { + u.SetNull(authidentity.FieldIssuer) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityUpsert { + u.Set(authidentity.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsert) UpdateMetadata() *AuthIdentityUpsert { + u.SetExcluded(authidentity.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) UpdateNewValues() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertOne) Ignore() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertOne) DoNothing() *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertOne) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUpdatedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertOne) SetUserID(v int64) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateUserID() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertOne) SetProviderType(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderType() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertOne) SetProviderKey(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderKey() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertOne) SetProviderSubject(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateProviderSubject() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertOne) SetVerifiedAt(v time.Time) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertOne) ClearVerifiedAt() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertOne) SetIssuer(v string) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertOne) ClearIssuer() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertOne) UpdateMetadata() *AuthIdentityUpsertOne { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityCreateBulk is the builder for creating many AuthIdentity entities in bulk. +type AuthIdentityCreateBulk struct { + config + err error + builders []*AuthIdentityCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentity entities in the database. +func (_c *AuthIdentityCreateBulk) Save(ctx context.Context) ([]*AuthIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) SaveX(ctx context.Context) []*AuthIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentity.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityUpsertBulk { + _c.conflict = opts + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityUpsertBulk{ + create: _c, + } +} + +// AuthIdentityUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentity nodes. +type AuthIdentityUpsertBulk struct { + create *AuthIdentityCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) UpdateNewValues() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentity.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentity.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityUpsertBulk) Ignore() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityUpsertBulk) DoNothing() *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityUpsertBulk) Update(set func(*AuthIdentityUpsert)) *AuthIdentityUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUpdatedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetUserID sets the "user_id" field. +func (u *AuthIdentityUpsertBulk) SetUserID(v int64) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateUserID() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateUserID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityUpsertBulk) SetProviderType(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderType() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityUpsertBulk) SetProviderKey(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderKey() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *AuthIdentityUpsertBulk) SetProviderSubject(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateProviderSubject() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetVerifiedAt sets the "verified_at" field. +func (u *AuthIdentityUpsertBulk) SetVerifiedAt(v time.Time) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetVerifiedAt(v) + }) +} + +// UpdateVerifiedAt sets the "verified_at" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateVerifiedAt() + }) +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (u *AuthIdentityUpsertBulk) ClearVerifiedAt() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearVerifiedAt() + }) +} + +// SetIssuer sets the "issuer" field. +func (u *AuthIdentityUpsertBulk) SetIssuer(v string) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetIssuer(v) + }) +} + +// UpdateIssuer sets the "issuer" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateIssuer() + }) +} + +// ClearIssuer clears the value of the "issuer" field. +func (u *AuthIdentityUpsertBulk) ClearIssuer() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.ClearIssuer() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityUpsertBulk) UpdateMetadata() *AuthIdentityUpsertBulk { + return u.Update(func(s *AuthIdentityUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_delete.go b/backend/ent/authidentity_delete.go new file mode 100644 index 00000000..4f1f6f3c --- /dev/null +++ b/backend/ent/authidentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityDelete is the builder for deleting a AuthIdentity entity. +type AuthIdentityDelete struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDelete) Where(ps ...predicate.AuthIdentity) *AuthIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentity.Table, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityDeleteOne is the builder for deleting a single AuthIdentity entity. +type AuthIdentityDeleteOne struct { + _d *AuthIdentityDelete +} + +// Where appends a list predicates to the AuthIdentityDelete builder. +func (_d *AuthIdentityDeleteOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentity_query.go b/backend/ent/authidentity_query.go new file mode 100644 index 00000000..ff27ef3c --- /dev/null +++ b/backend/ent/authidentity_query.go @@ -0,0 +1,797 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityQuery is the builder for querying AuthIdentity entities. +type AuthIdentityQuery struct { + config + ctx *QueryContext + order []authidentity.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentity + withUser *UserQuery + withChannels *AuthIdentityChannelQuery + withAdoptionDecisions *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityQuery builder. +func (_q *AuthIdentityQuery) Where(ps ...predicate.AuthIdentity) *AuthIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityQuery) Limit(limit int) *AuthIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityQuery) Offset(offset int) *AuthIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityQuery) Unique(unique bool) *AuthIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityQuery) Order(o ...authidentity.OrderOption) *AuthIdentityQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryUser chains the current query on the "user" edge. +func (_q *AuthIdentityQuery) QueryUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryChannels chains the current query on the "channels" edge. +func (_q *AuthIdentityQuery) QueryChannels() *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecisions chains the current query on the "adoption_decisions" edge. +func (_q *AuthIdentityQuery) QueryAdoptionDecisions() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentity entity from the query. +// Returns a *NotFoundError when no AuthIdentity was found. +func (_q *AuthIdentityQuery) First(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstX(ctx context.Context) *AuthIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentity ID from the query. +// Returns a *NotFoundError when no AuthIdentity ID was found. +func (_q *AuthIdentityQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentity entity is found. +// Returns a *NotFoundError when no AuthIdentity entities are found. +func (_q *AuthIdentityQuery) Only(ctx context.Context) (*AuthIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentity.Label} + default: + return nil, &NotSingularError{authidentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyX(ctx context.Context) *AuthIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentity ID in the query. +// Returns a *NotSingularError when more than one AuthIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentity.Label} + default: + err = &NotSingularError{authidentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentities. +func (_q *AuthIdentityQuery) All(ctx context.Context) ([]*AuthIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentity, *AuthIdentityQuery]() + return withInterceptors[[]*AuthIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityQuery) AllX(ctx context.Context) []*AuthIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentity IDs. +func (_q *AuthIdentityQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityQuery) Clone() *AuthIdentityQuery { + if _q == nil { + return nil + } + return &AuthIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentity{}, _q.predicates...), + withUser: _q.withUser.Clone(), + withChannels: _q.withChannels.Clone(), + withAdoptionDecisions: _q.withAdoptionDecisions.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithUser(opts ...func(*UserQuery)) *AuthIdentityQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withUser = query + return _q +} + +// WithChannels tells the query-builder to eager-load the nodes that are connected to +// the "channels" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithChannels(opts ...func(*AuthIdentityChannelQuery)) *AuthIdentityQuery { + query := (&AuthIdentityChannelClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withChannels = query + return _q +} + +// WithAdoptionDecisions tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decisions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityQuery) WithAdoptionDecisions(opts ...func(*IdentityAdoptionDecisionQuery)) *AuthIdentityQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecisions = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// GroupBy(authidentity.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) GroupBy(field string, fields ...string) *AuthIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentity.Query(). +// Select(authidentity.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityQuery) Select(fields ...string) *AuthIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentitySelect{AuthIdentityQuery: _q} + sbuild.label = authidentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentitySelect configured with the given aggregations. +func (_q *AuthIdentityQuery) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentity, error) { + var ( + nodes = []*AuthIdentity{} + _spec = _q.querySpec() + loadedTypes = [3]bool{ + _q.withUser != nil, + _q.withChannels != nil, + _q.withAdoptionDecisions != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentity{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withUser; query != nil { + if err := _q.loadUser(ctx, query, nodes, nil, + func(n *AuthIdentity, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := _q.withChannels; query != nil { + if err := _q.loadChannels(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.Channels = []*AuthIdentityChannel{} }, + func(n *AuthIdentity, e *AuthIdentityChannel) { n.Edges.Channels = append(n.Edges.Channels, e) }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecisions; query != nil { + if err := _q.loadAdoptionDecisions(ctx, query, nodes, + func(n *AuthIdentity) { n.Edges.AdoptionDecisions = []*IdentityAdoptionDecision{} }, + func(n *AuthIdentity, e *IdentityAdoptionDecision) { + n.Edges.AdoptionDecisions = append(n.Edges.AdoptionDecisions, e) + }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentity) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *AuthIdentityQuery) loadChannels(ctx context.Context, query *AuthIdentityChannelQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *AuthIdentityChannel)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentitychannel.FieldIdentityID) + } + query.Where(predicate.AuthIdentityChannel(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.ChannelsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *AuthIdentityQuery) loadAdoptionDecisions(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*AuthIdentity, init func(*AuthIdentity), assign func(*AuthIdentity, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*AuthIdentity) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldIdentityID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(authidentity.AdoptionDecisionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.IdentityID + if fk == nil { + return fmt.Errorf(`foreign-key "identity_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "identity_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *AuthIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for i := range fields { + if fields[i] != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withUser != nil { + _spec.Node.AddColumnOnce(authidentity.FieldUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityQuery) ForShare(opts ...sql.LockOption) *AuthIdentityQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityGroupBy is the group-by builder for AuthIdentity entities. +type AuthIdentityGroupBy struct { + selector + build *AuthIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityGroupBy) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentitySelect is the builder for selecting fields of AuthIdentity entities. +type AuthIdentitySelect struct { + *AuthIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentitySelect) Aggregate(fns ...AggregateFunc) *AuthIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityQuery, *AuthIdentitySelect](ctx, _s.AuthIdentityQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentitySelect) sqlScan(ctx context.Context, root *AuthIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentity_update.go b/backend/ent/authidentity_update.go new file mode 100644 index 00000000..c457470b --- /dev/null +++ b/backend/ent/authidentity_update.go @@ -0,0 +1,923 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// AuthIdentityUpdate is the builder for updating AuthIdentity entities. +type AuthIdentityUpdate struct { + config + hooks []Hook + mutation *AuthIdentityMutation +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdate) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdate) SetUpdatedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdate) SetUserID(v int64) *AuthIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableUserID(v *int64) *AuthIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdate) SetProviderType(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderType(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdate) SetProviderKey(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderKey(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdate) SetProviderSubject(v string) *AuthIdentityUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableProviderSubject(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdate) SetVerifiedAt(v time.Time) *AuthIdentityUpdate { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdate { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdate) ClearVerifiedAt() *AuthIdentityUpdate { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdate) SetIssuer(v string) *AuthIdentityUpdate { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdate) SetNillableIssuer(v *string) *AuthIdentityUpdate { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdate) ClearIssuer() *AuthIdentityUpdate { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) SetUser(v *User) *AuthIdentityUpdate { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdate) AddChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdate) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdate) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdate) ClearUser() *AuthIdentityUpdate { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdate) ClearChannels() *AuthIdentityUpdate { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdate) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdate) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdate) ClearAdoptionDecisions() *AuthIdentityUpdate { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdate { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdate) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityUpdateOne is the builder for updating a single AuthIdentity entity. +type AuthIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *AuthIdentityUpdateOne) SetUserID(v int64) *AuthIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableUserID(v *int64) *AuthIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityUpdateOne) SetProviderType(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderType(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityUpdateOne) SetProviderKey(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *AuthIdentityUpdateOne) SetProviderSubject(v string) *AuthIdentityUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableProviderSubject(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetVerifiedAt sets the "verified_at" field. +func (_u *AuthIdentityUpdateOne) SetVerifiedAt(v time.Time) *AuthIdentityUpdateOne { + _u.mutation.SetVerifiedAt(v) + return _u +} + +// SetNillableVerifiedAt sets the "verified_at" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableVerifiedAt(v *time.Time) *AuthIdentityUpdateOne { + if v != nil { + _u.SetVerifiedAt(*v) + } + return _u +} + +// ClearVerifiedAt clears the value of the "verified_at" field. +func (_u *AuthIdentityUpdateOne) ClearVerifiedAt() *AuthIdentityUpdateOne { + _u.mutation.ClearVerifiedAt() + return _u +} + +// SetIssuer sets the "issuer" field. +func (_u *AuthIdentityUpdateOne) SetIssuer(v string) *AuthIdentityUpdateOne { + _u.mutation.SetIssuer(v) + return _u +} + +// SetNillableIssuer sets the "issuer" field if the given value is not nil. +func (_u *AuthIdentityUpdateOne) SetNillableIssuer(v *string) *AuthIdentityUpdateOne { + if v != nil { + _u.SetIssuer(*v) + } + return _u +} + +// ClearIssuer clears the value of the "issuer" field. +func (_u *AuthIdentityUpdateOne) ClearIssuer() *AuthIdentityUpdateOne { + _u.mutation.ClearIssuer() + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetUser sets the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) SetUser(v *User) *AuthIdentityUpdateOne { + return _u.SetUserID(v.ID) +} + +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by IDs. +func (_u *AuthIdentityUpdateOne) AddChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddChannelIDs(ids...) + return _u +} + +// AddChannels adds the "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) AddChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddChannelIDs(ids...) +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.AddAdoptionDecisionIDs(ids...) + return _u +} + +// AddAdoptionDecisions adds the "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) AddAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAdoptionDecisionIDs(ids...) +} + +// Mutation returns the AuthIdentityMutation object of the builder. +func (_u *AuthIdentityUpdateOne) Mutation() *AuthIdentityMutation { + return _u.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (_u *AuthIdentityUpdateOne) ClearUser() *AuthIdentityUpdateOne { + _u.mutation.ClearUser() + return _u +} + +// ClearChannels clears all "channels" edges to the AuthIdentityChannel entity. +func (_u *AuthIdentityUpdateOne) ClearChannels() *AuthIdentityUpdateOne { + _u.mutation.ClearChannels() + return _u +} + +// RemoveChannelIDs removes the "channels" edge to AuthIdentityChannel entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveChannelIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveChannelIDs(ids...) + return _u +} + +// RemoveChannels removes "channels" edges to AuthIdentityChannel entities. +func (_u *AuthIdentityUpdateOne) RemoveChannels(v ...*AuthIdentityChannel) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveChannelIDs(ids...) +} + +// ClearAdoptionDecisions clears all "adoption_decisions" edges to the IdentityAdoptionDecision entity. +func (_u *AuthIdentityUpdateOne) ClearAdoptionDecisions() *AuthIdentityUpdateOne { + _u.mutation.ClearAdoptionDecisions() + return _u +} + +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to IdentityAdoptionDecision entities by IDs. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisionIDs(ids ...int64) *AuthIdentityUpdateOne { + _u.mutation.RemoveAdoptionDecisionIDs(ids...) + return _u +} + +// RemoveAdoptionDecisions removes "adoption_decisions" edges to IdentityAdoptionDecision entities. +func (_u *AuthIdentityUpdateOne) RemoveAdoptionDecisions(v ...*IdentityAdoptionDecision) *AuthIdentityUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAdoptionDecisionIDs(ids...) +} + +// Where appends a list predicates to the AuthIdentityUpdate builder. +func (_u *AuthIdentityUpdateOne) Where(ps ...predicate.AuthIdentity) *AuthIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityUpdateOne) Select(field string, fields ...string) *AuthIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentity entity. +func (_u *AuthIdentityUpdateOne) Save(ctx context.Context) (*AuthIdentity, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) SaveX(ctx context.Context) *AuthIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentity.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentity.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentity.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := authidentity.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentity.provider_subject": %w`, err)} + } + } + if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentity.user"`) + } + return nil +} + +func (_u *AuthIdentityUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentity.Table, authidentity.Columns, sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentity.FieldID) + for _, f := range fields { + if !authidentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentity.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentity.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentity.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(authidentity.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.VerifiedAt(); ok { + _spec.SetField(authidentity.FieldVerifiedAt, field.TypeTime, value) + } + if _u.mutation.VerifiedAtCleared() { + _spec.ClearField(authidentity.FieldVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.Issuer(); ok { + _spec.SetField(authidentity.FieldIssuer, field.TypeString, value) + } + if _u.mutation.IssuerCleared() { + _spec.ClearField(authidentity.FieldIssuer, field.TypeString) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentity.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentity.UserTable, + Columns: []string{authidentity.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedChannelsIDs(); len(nodes) > 0 && !_u.mutation.ChannelsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.ChannelsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.ChannelsTable, + Columns: []string{authidentity.ChannelsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAdoptionDecisionsIDs(); len(nodes) > 0 && !_u.mutation.AdoptionDecisionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: authidentity.AdoptionDecisionsTable, + Columns: []string{authidentity.AdoptionDecisionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/authidentitychannel.go b/backend/ent/authidentitychannel.go new file mode 100644 index 00000000..1ff3e5d1 --- /dev/null +++ b/backend/ent/authidentitychannel.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannel is the model entity for the AuthIdentityChannel schema. +type AuthIdentityChannel struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID int64 `json:"identity_id,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // Channel holds the value of the "channel" field. + Channel string `json:"channel,omitempty"` + // ChannelAppID holds the value of the "channel_app_id" field. + ChannelAppID string `json:"channel_app_id,omitempty"` + // ChannelSubject holds the value of the "channel_subject" field. + ChannelSubject string `json:"channel_subject,omitempty"` + // Metadata holds the value of the "metadata" field. + Metadata map[string]interface{} `json:"metadata,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the AuthIdentityChannelQuery when eager-loading is set. + Edges AuthIdentityChannelEdges `json:"edges"` + selectValues sql.SelectValues +} + +// AuthIdentityChannelEdges holds the relations/edges for other nodes in the graph. +type AuthIdentityChannelEdges struct { + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [1]bool +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e AuthIdentityChannelEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthIdentityChannel) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldMetadata: + values[i] = new([]byte) + case authidentitychannel.FieldID, authidentitychannel.FieldIdentityID: + values[i] = new(sql.NullInt64) + case authidentitychannel.FieldProviderType, authidentitychannel.FieldProviderKey, authidentitychannel.FieldChannel, authidentitychannel.FieldChannelAppID, authidentitychannel.FieldChannelSubject: + values[i] = new(sql.NullString) + case authidentitychannel.FieldCreatedAt, authidentitychannel.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthIdentityChannel fields. +func (_m *AuthIdentityChannel) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authidentitychannel.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case authidentitychannel.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authidentitychannel.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case authidentitychannel.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = value.Int64 + } + case authidentitychannel.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case authidentitychannel.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case authidentitychannel.FieldChannel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel", values[i]) + } else if value.Valid { + _m.Channel = value.String + } + case authidentitychannel.FieldChannelAppID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_app_id", values[i]) + } else if value.Valid { + _m.ChannelAppID = value.String + } + case authidentitychannel.FieldChannelSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field channel_subject", values[i]) + } else if value.Valid { + _m.ChannelSubject = value.String + } + case authidentitychannel.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthIdentityChannel. +// This includes values selected through modifiers, order, etc. +func (_m *AuthIdentityChannel) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryIdentity queries the "identity" edge of the AuthIdentityChannel entity. +func (_m *AuthIdentityChannel) QueryIdentity() *AuthIdentityQuery { + return NewAuthIdentityChannelClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this AuthIdentityChannel. +// Note that you need to call AuthIdentityChannel.Unwrap() before calling this method if this AuthIdentityChannel +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthIdentityChannel) Update() *AuthIdentityChannelUpdateOne { + return NewAuthIdentityChannelClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthIdentityChannel entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthIdentityChannel) Unwrap() *AuthIdentityChannel { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: AuthIdentityChannel is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthIdentityChannel) String() string { + var builder strings.Builder + builder.WriteString("AuthIdentityChannel(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", _m.IdentityID)) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("channel=") + builder.WriteString(_m.Channel) + builder.WriteString(", ") + builder.WriteString("channel_app_id=") + builder.WriteString(_m.ChannelAppID) + builder.WriteString(", ") + builder.WriteString("channel_subject=") + builder.WriteString(_m.ChannelSubject) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", _m.Metadata)) + builder.WriteByte(')') + return builder.String() +} + +// AuthIdentityChannels is a parsable slice of AuthIdentityChannel. +type AuthIdentityChannels []*AuthIdentityChannel diff --git a/backend/ent/authidentitychannel/authidentitychannel.go b/backend/ent/authidentitychannel/authidentitychannel.go new file mode 100644 index 00000000..7dcc98bb --- /dev/null +++ b/backend/ent/authidentitychannel/authidentitychannel.go @@ -0,0 +1,153 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the authidentitychannel type in the database. + Label = "auth_identity_channel" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldChannel holds the string denoting the channel field in the database. + FieldChannel = "channel" + // FieldChannelAppID holds the string denoting the channel_app_id field in the database. + FieldChannelAppID = "channel_app_id" + // FieldChannelSubject holds the string denoting the channel_subject field in the database. + FieldChannelSubject = "channel_subject" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the authidentitychannel in the database. + Table = "auth_identity_channels" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "auth_identity_channels" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for authidentitychannel fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldIdentityID, + FieldProviderType, + FieldProviderKey, + FieldChannel, + FieldChannelAppID, + FieldChannelSubject, + FieldMetadata, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + ChannelValidator func(string) error + // ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + ChannelAppIDValidator func(string) error + // ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + ChannelSubjectValidator func(string) error + // DefaultMetadata holds the default value on creation for the "metadata" field. + DefaultMetadata func() map[string]interface{} +) + +// OrderOption defines the ordering options for the AuthIdentityChannel queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByChannel orders the results by the channel field. +func ByChannel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannel, opts...).ToFunc() +} + +// ByChannelAppID orders the results by the channel_app_id field. +func ByChannelAppID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelAppID, opts...).ToFunc() +} + +// ByChannelSubject orders the results by the channel_subject field. +func ByChannelSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldChannelSubject, opts...).ToFunc() +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/authidentitychannel/where.go b/backend/ent/authidentitychannel/where.go new file mode 100644 index 00000000..827dc384 --- /dev/null +++ b/backend/ent/authidentitychannel/where.go @@ -0,0 +1,559 @@ +// Code generated by ent, DO NOT EDIT. + +package authidentitychannel + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// Channel applies equality check predicate on the "channel" field. It's identical to ChannelEQ. +func Channel(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelAppID applies equality check predicate on the "channel_app_id" field. It's identical to ChannelAppIDEQ. +func ChannelAppID(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelSubject applies equality check predicate on the "channel_subject" field. It's identical to ChannelSubjectEQ. +func ChannelSubject(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ChannelEQ applies the EQ predicate on the "channel" field. +func ChannelEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannel, v)) +} + +// ChannelNEQ applies the NEQ predicate on the "channel" field. +func ChannelNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannel, v)) +} + +// ChannelIn applies the In predicate on the "channel" field. +func ChannelIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannel, vs...)) +} + +// ChannelNotIn applies the NotIn predicate on the "channel" field. +func ChannelNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannel, vs...)) +} + +// ChannelGT applies the GT predicate on the "channel" field. +func ChannelGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannel, v)) +} + +// ChannelGTE applies the GTE predicate on the "channel" field. +func ChannelGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannel, v)) +} + +// ChannelLT applies the LT predicate on the "channel" field. +func ChannelLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannel, v)) +} + +// ChannelLTE applies the LTE predicate on the "channel" field. +func ChannelLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannel, v)) +} + +// ChannelContains applies the Contains predicate on the "channel" field. +func ChannelContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannel, v)) +} + +// ChannelHasPrefix applies the HasPrefix predicate on the "channel" field. +func ChannelHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannel, v)) +} + +// ChannelHasSuffix applies the HasSuffix predicate on the "channel" field. +func ChannelHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannel, v)) +} + +// ChannelEqualFold applies the EqualFold predicate on the "channel" field. +func ChannelEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannel, v)) +} + +// ChannelContainsFold applies the ContainsFold predicate on the "channel" field. +func ChannelContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannel, v)) +} + +// ChannelAppIDEQ applies the EQ predicate on the "channel_app_id" field. +func ChannelAppIDEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDNEQ applies the NEQ predicate on the "channel_app_id" field. +func ChannelAppIDNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelAppID, v)) +} + +// ChannelAppIDIn applies the In predicate on the "channel_app_id" field. +func ChannelAppIDIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDNotIn applies the NotIn predicate on the "channel_app_id" field. +func ChannelAppIDNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelAppID, vs...)) +} + +// ChannelAppIDGT applies the GT predicate on the "channel_app_id" field. +func ChannelAppIDGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelAppID, v)) +} + +// ChannelAppIDGTE applies the GTE predicate on the "channel_app_id" field. +func ChannelAppIDGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelAppID, v)) +} + +// ChannelAppIDLT applies the LT predicate on the "channel_app_id" field. +func ChannelAppIDLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelAppID, v)) +} + +// ChannelAppIDLTE applies the LTE predicate on the "channel_app_id" field. +func ChannelAppIDLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelAppID, v)) +} + +// ChannelAppIDContains applies the Contains predicate on the "channel_app_id" field. +func ChannelAppIDContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelAppID, v)) +} + +// ChannelAppIDHasPrefix applies the HasPrefix predicate on the "channel_app_id" field. +func ChannelAppIDHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelAppID, v)) +} + +// ChannelAppIDHasSuffix applies the HasSuffix predicate on the "channel_app_id" field. +func ChannelAppIDHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelAppID, v)) +} + +// ChannelAppIDEqualFold applies the EqualFold predicate on the "channel_app_id" field. +func ChannelAppIDEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelAppID, v)) +} + +// ChannelAppIDContainsFold applies the ContainsFold predicate on the "channel_app_id" field. +func ChannelAppIDContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelAppID, v)) +} + +// ChannelSubjectEQ applies the EQ predicate on the "channel_subject" field. +func ChannelSubjectEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectNEQ applies the NEQ predicate on the "channel_subject" field. +func ChannelSubjectNEQ(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNEQ(FieldChannelSubject, v)) +} + +// ChannelSubjectIn applies the In predicate on the "channel_subject" field. +func ChannelSubjectIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectNotIn applies the NotIn predicate on the "channel_subject" field. +func ChannelSubjectNotIn(vs ...string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldNotIn(FieldChannelSubject, vs...)) +} + +// ChannelSubjectGT applies the GT predicate on the "channel_subject" field. +func ChannelSubjectGT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGT(FieldChannelSubject, v)) +} + +// ChannelSubjectGTE applies the GTE predicate on the "channel_subject" field. +func ChannelSubjectGTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldGTE(FieldChannelSubject, v)) +} + +// ChannelSubjectLT applies the LT predicate on the "channel_subject" field. +func ChannelSubjectLT(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLT(FieldChannelSubject, v)) +} + +// ChannelSubjectLTE applies the LTE predicate on the "channel_subject" field. +func ChannelSubjectLTE(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldLTE(FieldChannelSubject, v)) +} + +// ChannelSubjectContains applies the Contains predicate on the "channel_subject" field. +func ChannelSubjectContains(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContains(FieldChannelSubject, v)) +} + +// ChannelSubjectHasPrefix applies the HasPrefix predicate on the "channel_subject" field. +func ChannelSubjectHasPrefix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasPrefix(FieldChannelSubject, v)) +} + +// ChannelSubjectHasSuffix applies the HasSuffix predicate on the "channel_subject" field. +func ChannelSubjectHasSuffix(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldHasSuffix(FieldChannelSubject, v)) +} + +// ChannelSubjectEqualFold applies the EqualFold predicate on the "channel_subject" field. +func ChannelSubjectEqualFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldEqualFold(FieldChannelSubject, v)) +} + +// ChannelSubjectContainsFold applies the ContainsFold predicate on the "channel_subject" field. +func ChannelSubjectContainsFold(v string) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.FieldContainsFold(FieldChannelSubject, v)) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthIdentityChannel) predicate.AuthIdentityChannel { + return predicate.AuthIdentityChannel(sql.NotPredicates(p)) +} diff --git a/backend/ent/authidentitychannel_create.go b/backend/ent/authidentitychannel_create.go new file mode 100644 index 00000000..4ce28479 --- /dev/null +++ b/backend/ent/authidentitychannel_create.go @@ -0,0 +1,932 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" +) + +// AuthIdentityChannelCreate is the builder for creating a AuthIdentityChannel entity. +type AuthIdentityChannelCreate struct { + config + mutation *AuthIdentityChannelMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthIdentityChannelCreate) SetCreatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableCreatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *AuthIdentityChannelCreate) SetUpdatedAt(v time.Time) *AuthIdentityChannelCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *AuthIdentityChannelCreate) SetNillableUpdatedAt(v *time.Time) *AuthIdentityChannelCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *AuthIdentityChannelCreate) SetIdentityID(v int64) *AuthIdentityChannelCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *AuthIdentityChannelCreate) SetProviderType(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *AuthIdentityChannelCreate) SetProviderKey(v string) *AuthIdentityChannelCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetChannel sets the "channel" field. +func (_c *AuthIdentityChannelCreate) SetChannel(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannel(v) + return _c +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_c *AuthIdentityChannelCreate) SetChannelAppID(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelAppID(v) + return _c +} + +// SetChannelSubject sets the "channel_subject" field. +func (_c *AuthIdentityChannelCreate) SetChannelSubject(v string) *AuthIdentityChannelCreate { + _c.mutation.SetChannelSubject(v) + return _c +} + +// SetMetadata sets the "metadata" field. +func (_c *AuthIdentityChannelCreate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelCreate { + _c.mutation.SetMetadata(v) + return _c +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *AuthIdentityChannelCreate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_c *AuthIdentityChannelCreate) Mutation() *AuthIdentityChannelMutation { + return _c.mutation +} + +// Save creates the AuthIdentityChannel in the database. +func (_c *AuthIdentityChannelCreate) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthIdentityChannelCreate) SaveX(ctx context.Context) *AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthIdentityChannelCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := authidentitychannel.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := authidentitychannel.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Metadata(); !ok { + v := authidentitychannel.DefaultMetadata() + _c.mutation.SetMetadata(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthIdentityChannelCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "AuthIdentityChannel.updated_at"`)} + } + if _, ok := _c.mutation.IdentityID(); !ok { + return &ValidationError{Name: "identity_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.identity_id"`)} + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "AuthIdentityChannel.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.Channel(); !ok { + return &ValidationError{Name: "channel", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel"`)} + } + if v, ok := _c.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelAppID(); !ok { + return &ValidationError{Name: "channel_app_id", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_app_id"`)} + } + if v, ok := _c.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if _, ok := _c.mutation.ChannelSubject(); !ok { + return &ValidationError{Name: "channel_subject", err: errors.New(`ent: missing required field "AuthIdentityChannel.channel_subject"`)} + } + if v, ok := _c.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _, ok := _c.mutation.Metadata(); !ok { + return &ValidationError{Name: "metadata", err: errors.New(`ent: missing required field "AuthIdentityChannel.metadata"`)} + } + if len(_c.mutation.IdentityIDs()) == 0 { + return &ValidationError{Name: "identity", err: errors.New(`ent: missing required edge "AuthIdentityChannel.identity"`)} + } + return nil +} + +func (_c *AuthIdentityChannelCreate) sqlSave(ctx context.Context) (*AuthIdentityChannel, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthIdentityChannelCreate) createSpec() (*AuthIdentityChannel, *sqlgraph.CreateSpec) { + var ( + _node = &AuthIdentityChannel{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authidentitychannel.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + _node.Channel = value + } + if value, ok := _c.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + _node.ChannelAppID = value + } + if value, ok := _c.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + _node.ChannelSubject = value + } + if value, ok := _c.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertOne { + _c.conflict = opts + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreate) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertOne{ + create: _c, + } +} + +type ( + // AuthIdentityChannelUpsertOne is the builder for "upsert"-ing + // one AuthIdentityChannel node. + AuthIdentityChannelUpsertOne struct { + create *AuthIdentityChannelCreate + } + + // AuthIdentityChannelUpsert is the "OnConflict" setter. + AuthIdentityChannelUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsert) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateUpdatedAt() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldUpdatedAt) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsert) SetIdentityID(v int64) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateIdentityID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldIdentityID) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsert) SetProviderType(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderType() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsert) SetProviderKey(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateProviderKey() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldProviderKey) + return u +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsert) SetChannel(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannel, v) + return u +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannel() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannel) + return u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsert) SetChannelAppID(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelAppID, v) + return u +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelAppID() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelAppID) + return u +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsert) SetChannelSubject(v string) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldChannelSubject, v) + return u +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateChannelSubject() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldChannelSubject) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsert) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsert { + u.Set(authidentitychannel.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsert) UpdateMetadata() *AuthIdentityChannelUpsert { + u.SetExcluded(authidentitychannel.FieldMetadata) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) UpdateNewValues() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertOne) Ignore() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertOne) DoNothing() *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreate.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertOne) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateUpdatedAt() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertOne) SetIdentityID(v int64) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateIdentityID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderType(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderType() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertOne) SetProviderKey(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateProviderKey() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertOne) SetChannel(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannel() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelAppID(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelAppID() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertOne) SetChannelSubject(v string) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateChannelSubject() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertOne) UpdateMetadata() *AuthIdentityChannelUpsertOne { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *AuthIdentityChannelUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// AuthIdentityChannelCreateBulk is the builder for creating many AuthIdentityChannel entities in bulk. +type AuthIdentityChannelCreateBulk struct { + config + err error + builders []*AuthIdentityChannelCreate + conflict []sql.ConflictOption +} + +// Save creates the AuthIdentityChannel entities in the database. +func (_c *AuthIdentityChannelCreateBulk) Save(ctx context.Context) ([]*AuthIdentityChannel, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthIdentityChannel, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthIdentityChannelMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) SaveX(ctx context.Context) []*AuthIdentityChannel { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthIdentityChannelCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthIdentityChannelCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.AuthIdentityChannel.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.AuthIdentityChannelUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflict(opts ...sql.ConflictOption) *AuthIdentityChannelUpsertBulk { + _c.conflict = opts + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *AuthIdentityChannelCreateBulk) OnConflictColumns(columns ...string) *AuthIdentityChannelUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &AuthIdentityChannelUpsertBulk{ + create: _c, + } +} + +// AuthIdentityChannelUpsertBulk is the builder for "upsert"-ing +// a bulk of AuthIdentityChannel nodes. +type AuthIdentityChannelUpsertBulk struct { + create *AuthIdentityChannelCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) UpdateNewValues() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(authidentitychannel.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.AuthIdentityChannel.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *AuthIdentityChannelUpsertBulk) Ignore() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *AuthIdentityChannelUpsertBulk) DoNothing() *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the AuthIdentityChannelCreateBulk.OnConflict +// documentation for more info. +func (u *AuthIdentityChannelUpsertBulk) Update(set func(*AuthIdentityChannelUpsert)) *AuthIdentityChannelUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&AuthIdentityChannelUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *AuthIdentityChannelUpsertBulk) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateUpdatedAt() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetIdentityID(v int64) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateIdentityID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateIdentityID() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderType(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderType() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *AuthIdentityChannelUpsertBulk) SetProviderKey(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateProviderKey() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateProviderKey() + }) +} + +// SetChannel sets the "channel" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannel(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannel(v) + }) +} + +// UpdateChannel sets the "channel" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannel() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannel() + }) +} + +// SetChannelAppID sets the "channel_app_id" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelAppID(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelAppID(v) + }) +} + +// UpdateChannelAppID sets the "channel_app_id" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelAppID() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelAppID() + }) +} + +// SetChannelSubject sets the "channel_subject" field. +func (u *AuthIdentityChannelUpsertBulk) SetChannelSubject(v string) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetChannelSubject(v) + }) +} + +// UpdateChannelSubject sets the "channel_subject" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateChannelSubject() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateChannelSubject() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *AuthIdentityChannelUpsertBulk) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *AuthIdentityChannelUpsertBulk) UpdateMetadata() *AuthIdentityChannelUpsertBulk { + return u.Update(func(s *AuthIdentityChannelUpsert) { + s.UpdateMetadata() + }) +} + +// Exec executes the query. +func (u *AuthIdentityChannelUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the AuthIdentityChannelCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for AuthIdentityChannelCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *AuthIdentityChannelUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_delete.go b/backend/ent/authidentitychannel_delete.go new file mode 100644 index 00000000..1a4acac5 --- /dev/null +++ b/backend/ent/authidentitychannel_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelDelete is the builder for deleting a AuthIdentityChannel entity. +type AuthIdentityChannelDelete struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDelete) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthIdentityChannelDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthIdentityChannelDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authidentitychannel.Table, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthIdentityChannelDeleteOne is the builder for deleting a single AuthIdentityChannel entity. +type AuthIdentityChannelDeleteOne struct { + _d *AuthIdentityChannelDelete +} + +// Where appends a list predicates to the AuthIdentityChannelDelete builder. +func (_d *AuthIdentityChannelDeleteOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthIdentityChannelDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authidentitychannel.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthIdentityChannelDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/authidentitychannel_query.go b/backend/ent/authidentitychannel_query.go new file mode 100644 index 00000000..7a202b7f --- /dev/null +++ b/backend/ent/authidentitychannel_query.go @@ -0,0 +1,643 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelQuery is the builder for querying AuthIdentityChannel entities. +type AuthIdentityChannelQuery struct { + config + ctx *QueryContext + order []authidentitychannel.OrderOption + inters []Interceptor + predicates []predicate.AuthIdentityChannel + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthIdentityChannelQuery builder. +func (_q *AuthIdentityChannelQuery) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthIdentityChannelQuery) Limit(limit int) *AuthIdentityChannelQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthIdentityChannelQuery) Offset(offset int) *AuthIdentityChannelQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthIdentityChannelQuery) Unique(unique bool) *AuthIdentityChannelQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthIdentityChannelQuery) Order(o ...authidentitychannel.OrderOption) *AuthIdentityChannelQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *AuthIdentityChannelQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first AuthIdentityChannel entity from the query. +// Returns a *NotFoundError when no AuthIdentityChannel was found. +func (_q *AuthIdentityChannelQuery) First(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authidentitychannel.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthIdentityChannel ID from the query. +// Returns a *NotFoundError when no AuthIdentityChannel ID was found. +func (_q *AuthIdentityChannelQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authidentitychannel.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthIdentityChannel entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthIdentityChannel entity is found. +// Returns a *NotFoundError when no AuthIdentityChannel entities are found. +func (_q *AuthIdentityChannelQuery) Only(ctx context.Context) (*AuthIdentityChannel, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authidentitychannel.Label} + default: + return nil, &NotSingularError{authidentitychannel.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyX(ctx context.Context) *AuthIdentityChannel { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthIdentityChannel ID in the query. +// Returns a *NotSingularError when more than one AuthIdentityChannel ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthIdentityChannelQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authidentitychannel.Label} + default: + err = &NotSingularError{authidentitychannel.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthIdentityChannels. +func (_q *AuthIdentityChannelQuery) All(ctx context.Context) ([]*AuthIdentityChannel, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthIdentityChannel, *AuthIdentityChannelQuery]() + return withInterceptors[[]*AuthIdentityChannel](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) AllX(ctx context.Context) []*AuthIdentityChannel { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthIdentityChannel IDs. +func (_q *AuthIdentityChannelQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authidentitychannel.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthIdentityChannelQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthIdentityChannelQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthIdentityChannelQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthIdentityChannelQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthIdentityChannelQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthIdentityChannelQuery) Clone() *AuthIdentityChannelQuery { + if _q == nil { + return nil + } + return &AuthIdentityChannelQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authidentitychannel.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthIdentityChannel{}, _q.predicates...), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *AuthIdentityChannelQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *AuthIdentityChannelQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// GroupBy(authidentitychannel.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) GroupBy(field string, fields ...string) *AuthIdentityChannelGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthIdentityChannelGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authidentitychannel.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.AuthIdentityChannel.Query(). +// Select(authidentitychannel.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *AuthIdentityChannelQuery) Select(fields ...string) *AuthIdentityChannelSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthIdentityChannelSelect{AuthIdentityChannelQuery: _q} + sbuild.label = authidentitychannel.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthIdentityChannelSelect configured with the given aggregations. +func (_q *AuthIdentityChannelQuery) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthIdentityChannelQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authidentitychannel.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthIdentityChannel, error) { + var ( + nodes = []*AuthIdentityChannel{} + _spec = _q.querySpec() + loadedTypes = [1]bool{ + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthIdentityChannel).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthIdentityChannel{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *AuthIdentityChannel, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *AuthIdentityChannelQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*AuthIdentityChannel, init func(*AuthIdentityChannel), assign func(*AuthIdentityChannel, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*AuthIdentityChannel) + for i := range nodes { + fk := nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *AuthIdentityChannelQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthIdentityChannelQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for i := range fields { + if fields[i] != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(authidentitychannel.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthIdentityChannelQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authidentitychannel.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authidentitychannel.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *AuthIdentityChannelQuery) ForUpdate(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *AuthIdentityChannelQuery) ForShare(opts ...sql.LockOption) *AuthIdentityChannelQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// AuthIdentityChannelGroupBy is the group-by builder for AuthIdentityChannel entities. +type AuthIdentityChannelGroupBy struct { + selector + build *AuthIdentityChannelQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthIdentityChannelGroupBy) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthIdentityChannelGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthIdentityChannelGroupBy) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthIdentityChannelSelect is the builder for selecting fields of AuthIdentityChannel entities. +type AuthIdentityChannelSelect struct { + *AuthIdentityChannelQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthIdentityChannelSelect) Aggregate(fns ...AggregateFunc) *AuthIdentityChannelSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthIdentityChannelSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthIdentityChannelQuery, *AuthIdentityChannelSelect](ctx, _s.AuthIdentityChannelQuery, _s, _s.inters, v) +} + +func (_s *AuthIdentityChannelSelect) sqlScan(ctx context.Context, root *AuthIdentityChannelQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/authidentitychannel_update.go b/backend/ent/authidentitychannel_update.go new file mode 100644 index 00000000..b550c454 --- /dev/null +++ b/backend/ent/authidentitychannel_update.go @@ -0,0 +1,581 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// AuthIdentityChannelUpdate is the builder for updating AuthIdentityChannel entities. +type AuthIdentityChannelUpdate struct { + config + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdate) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdate) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdate) SetIdentityID(v int64) *AuthIdentityChannelUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdate) SetProviderType(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderType(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdate) SetProviderKey(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdate) SetChannel(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannel(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdate) SetChannelAppID(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdate) SetChannelSubject(v string) *AuthIdentityChannelUpdate { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdate) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdate { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdate) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdate { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdate) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdate) ClearIdentity() *AuthIdentityChannelUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthIdentityChannelUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthIdentityChannelUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdate) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthIdentityChannelUpdateOne is the builder for updating a single AuthIdentityChannel entity. +type AuthIdentityChannelUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthIdentityChannelMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *AuthIdentityChannelUpdateOne) SetUpdatedAt(v time.Time) *AuthIdentityChannelUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetIdentityID(v int64) *AuthIdentityChannelUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableIdentityID(v *int64) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderType(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderType(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *AuthIdentityChannelUpdateOne) SetProviderKey(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableProviderKey(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetChannel sets the "channel" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannel(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannel(v) + return _u +} + +// SetNillableChannel sets the "channel" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannel(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannel(*v) + } + return _u +} + +// SetChannelAppID sets the "channel_app_id" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelAppID(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelAppID(v) + return _u +} + +// SetNillableChannelAppID sets the "channel_app_id" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelAppID(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelAppID(*v) + } + return _u +} + +// SetChannelSubject sets the "channel_subject" field. +func (_u *AuthIdentityChannelUpdateOne) SetChannelSubject(v string) *AuthIdentityChannelUpdateOne { + _u.mutation.SetChannelSubject(v) + return _u +} + +// SetNillableChannelSubject sets the "channel_subject" field if the given value is not nil. +func (_u *AuthIdentityChannelUpdateOne) SetNillableChannelSubject(v *string) *AuthIdentityChannelUpdateOne { + if v != nil { + _u.SetChannelSubject(*v) + } + return _u +} + +// SetMetadata sets the "metadata" field. +func (_u *AuthIdentityChannelUpdateOne) SetMetadata(v map[string]interface{}) *AuthIdentityChannelUpdateOne { + _u.mutation.SetMetadata(v) + return _u +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) SetIdentity(v *AuthIdentity) *AuthIdentityChannelUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the AuthIdentityChannelMutation object of the builder. +func (_u *AuthIdentityChannelUpdateOne) Mutation() *AuthIdentityChannelMutation { + return _u.mutation +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *AuthIdentityChannelUpdateOne) ClearIdentity() *AuthIdentityChannelUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the AuthIdentityChannelUpdate builder. +func (_u *AuthIdentityChannelUpdateOne) Where(ps ...predicate.AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthIdentityChannelUpdateOne) Select(field string, fields ...string) *AuthIdentityChannelUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthIdentityChannel entity. +func (_u *AuthIdentityChannelUpdateOne) Save(ctx context.Context) (*AuthIdentityChannel, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) SaveX(ctx context.Context) *AuthIdentityChannel { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthIdentityChannelUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthIdentityChannelUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *AuthIdentityChannelUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := authidentitychannel.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *AuthIdentityChannelUpdateOne) check() error { + if v, ok := _u.mutation.ProviderType(); ok { + if err := authidentitychannel.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := authidentitychannel.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.Channel(); ok { + if err := authidentitychannel.ChannelValidator(v); err != nil { + return &ValidationError{Name: "channel", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelAppID(); ok { + if err := authidentitychannel.ChannelAppIDValidator(v); err != nil { + return &ValidationError{Name: "channel_app_id", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_app_id": %w`, err)} + } + } + if v, ok := _u.mutation.ChannelSubject(); ok { + if err := authidentitychannel.ChannelSubjectValidator(v); err != nil { + return &ValidationError{Name: "channel_subject", err: fmt.Errorf(`ent: validator failed for field "AuthIdentityChannel.channel_subject": %w`, err)} + } + } + if _u.mutation.IdentityCleared() && len(_u.mutation.IdentityIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "AuthIdentityChannel.identity"`) + } + return nil +} + +func (_u *AuthIdentityChannelUpdateOne) sqlSave(ctx context.Context) (_node *AuthIdentityChannel, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(authidentitychannel.Table, authidentitychannel.Columns, sqlgraph.NewFieldSpec(authidentitychannel.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AuthIdentityChannel.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authidentitychannel.FieldID) + for _, f := range fields { + if !authidentitychannel.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != authidentitychannel.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(authidentitychannel.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(authidentitychannel.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(authidentitychannel.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.Channel(); ok { + _spec.SetField(authidentitychannel.FieldChannel, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelAppID(); ok { + _spec.SetField(authidentitychannel.FieldChannelAppID, field.TypeString, value) + } + if value, ok := _u.mutation.ChannelSubject(); ok { + _spec.SetField(authidentitychannel.FieldChannelSubject, field.TypeString, value) + } + if value, ok := _u.mutation.Metadata(); ok { + _spec.SetField(authidentitychannel.FieldMetadata, field.TypeJSON, value) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: authidentitychannel.IdentityTable, + Columns: []string{authidentitychannel.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &AuthIdentityChannel{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authidentitychannel.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/client.go b/backend/ent/client.go index e52e015a..b02f519b 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,12 +20,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -60,18 +64,26 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -118,12 +130,16 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.AuthIdentity = NewAuthIdentityClient(c.config) + c.AuthIdentityChannel = NewAuthIdentityChannelClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) + c.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(c.config) c.PaymentAuditLog = NewPaymentAuditLogClient(c.config) c.PaymentOrder = NewPaymentOrderClient(c.config) c.PaymentProviderInstance = NewPaymentProviderInstanceClient(c.config) + c.PendingAuthSession = NewPendingAuthSessionClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -229,34 +245,38 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { cfg := c.config cfg.driver = tx return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -274,34 +294,38 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) cfg := c.config cfg.driver = &txDriver{tx: tx, drv: c.driver} return &Tx{ - ctx: ctx, - config: cfg, - APIKey: NewAPIKeyClient(cfg), - Account: NewAccountClient(cfg), - AccountGroup: NewAccountGroupClient(cfg), - Announcement: NewAnnouncementClient(cfg), - AnnouncementRead: NewAnnouncementReadClient(cfg), - ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), - Group: NewGroupClient(cfg), - IdempotencyRecord: NewIdempotencyRecordClient(cfg), - PaymentAuditLog: NewPaymentAuditLogClient(cfg), - PaymentOrder: NewPaymentOrderClient(cfg), - PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), - PromoCode: NewPromoCodeClient(cfg), - PromoCodeUsage: NewPromoCodeUsageClient(cfg), - Proxy: NewProxyClient(cfg), - RedeemCode: NewRedeemCodeClient(cfg), - SecuritySecret: NewSecuritySecretClient(cfg), - Setting: NewSettingClient(cfg), - SubscriptionPlan: NewSubscriptionPlanClient(cfg), - TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), - UsageCleanupTask: NewUsageCleanupTaskClient(cfg), - UsageLog: NewUsageLogClient(cfg), - User: NewUserClient(cfg), - UserAllowedGroup: NewUserAllowedGroupClient(cfg), - UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), - UserAttributeValue: NewUserAttributeValueClient(cfg), - UserSubscription: NewUserSubscriptionClient(cfg), + ctx: ctx, + config: cfg, + APIKey: NewAPIKeyClient(cfg), + Account: NewAccountClient(cfg), + AccountGroup: NewAccountGroupClient(cfg), + Announcement: NewAnnouncementClient(cfg), + AnnouncementRead: NewAnnouncementReadClient(cfg), + AuthIdentity: NewAuthIdentityClient(cfg), + AuthIdentityChannel: NewAuthIdentityChannelClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), + Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), + IdentityAdoptionDecision: NewIdentityAdoptionDecisionClient(cfg), + PaymentAuditLog: NewPaymentAuditLogClient(cfg), + PaymentOrder: NewPaymentOrderClient(cfg), + PaymentProviderInstance: NewPaymentProviderInstanceClient(cfg), + PendingAuthSession: NewPendingAuthSessionClient(cfg), + PromoCode: NewPromoCodeClient(cfg), + PromoCodeUsage: NewPromoCodeUsageClient(cfg), + Proxy: NewProxyClient(cfg), + RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), + Setting: NewSettingClient(cfg), + SubscriptionPlan: NewSubscriptionPlanClient(cfg), + TLSFingerprintProfile: NewTLSFingerprintProfileClient(cfg), + UsageCleanupTask: NewUsageCleanupTaskClient(cfg), + UsageLog: NewUsageLogClient(cfg), + User: NewUserClient(cfg), + UserAllowedGroup: NewUserAllowedGroupClient(cfg), + UserAttributeDefinition: NewUserAttributeDefinitionClient(cfg), + UserAttributeValue: NewUserAttributeValueClient(cfg), + UserSubscription: NewUserSubscriptionClient(cfg), }, nil } @@ -332,11 +356,12 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) @@ -348,11 +373,12 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PaymentAuditLog, - c.PaymentOrder, c.PaymentProviderInstance, c.PromoCode, c.PromoCodeUsage, - c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, c.SubscriptionPlan, - c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.AuthIdentity, c.AuthIdentityChannel, c.ErrorPassthroughRule, c.Group, + c.IdempotencyRecord, c.IdentityAdoptionDecision, c.PaymentAuditLog, + c.PaymentOrder, c.PaymentProviderInstance, c.PendingAuthSession, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.SubscriptionPlan, c.TLSFingerprintProfile, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -372,18 +398,26 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *AuthIdentityMutation: + return c.AuthIdentity.mutate(ctx, m) + case *AuthIdentityChannelMutation: + return c.AuthIdentityChannel.mutate(ctx, m) case *ErrorPassthroughRuleMutation: return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *IdempotencyRecordMutation: return c.IdempotencyRecord.mutate(ctx, m) + case *IdentityAdoptionDecisionMutation: + return c.IdentityAdoptionDecision.mutate(ctx, m) case *PaymentAuditLogMutation: return c.PaymentAuditLog.mutate(ctx, m) case *PaymentOrderMutation: return c.PaymentOrder.mutate(ctx, m) case *PaymentProviderInstanceMutation: return c.PaymentProviderInstance.mutate(ctx, m) + case *PendingAuthSessionMutation: + return c.PendingAuthSession.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1231,6 +1265,336 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// AuthIdentityClient is a client for the AuthIdentity schema. +type AuthIdentityClient struct { + config +} + +// NewAuthIdentityClient returns a client for the AuthIdentity from the given config. +func NewAuthIdentityClient(c config) *AuthIdentityClient { + return &AuthIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentity.Hooks(f(g(h())))`. +func (c *AuthIdentityClient) Use(hooks ...Hook) { + c.hooks.AuthIdentity = append(c.hooks.AuthIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentity.Intercept(f(g(h())))`. +func (c *AuthIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentity = append(c.inters.AuthIdentity, interceptors...) +} + +// Create returns a builder for creating a AuthIdentity entity. +func (c *AuthIdentityClient) Create() *AuthIdentityCreate { + mutation := newAuthIdentityMutation(c.config, OpCreate) + return &AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentity entities. +func (c *AuthIdentityClient) CreateBulk(builders ...*AuthIdentityCreate) *AuthIdentityCreateBulk { + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityCreate, int)) *AuthIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityCreateBulk{err: fmt.Errorf("calling to AuthIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentity. +func (c *AuthIdentityClient) Update() *AuthIdentityUpdate { + mutation := newAuthIdentityMutation(c.config, OpUpdate) + return &AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityClient) UpdateOne(_m *AuthIdentity) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentity(_m)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityClient) UpdateOneID(id int64) *AuthIdentityUpdateOne { + mutation := newAuthIdentityMutation(c.config, OpUpdateOne, withAuthIdentityID(id)) + return &AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentity. +func (c *AuthIdentityClient) Delete() *AuthIdentityDelete { + mutation := newAuthIdentityMutation(c.config, OpDelete) + return &AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityClient) DeleteOne(_m *AuthIdentity) *AuthIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityClient) DeleteOneID(id int64) *AuthIdentityDeleteOne { + builder := c.Delete().Where(authidentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentity. +func (c *AuthIdentityClient) Query() *AuthIdentityQuery { + return &AuthIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentity entity by its id. +func (c *AuthIdentityClient) Get(ctx context.Context, id int64) (*AuthIdentity, error) { + return c.Query().Where(authidentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityClient) GetX(ctx context.Context, id int64) *AuthIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryUser(_m *AuthIdentity) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentity.UserTable, authidentity.UserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryChannels queries the channels edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryChannels(_m *AuthIdentity) *AuthIdentityChannelQuery { + query := (&AuthIdentityChannelClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(authidentitychannel.Table, authidentitychannel.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.ChannelsTable, authidentity.ChannelsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecisions queries the adoption_decisions edge of a AuthIdentity. +func (c *AuthIdentityClient) QueryAdoptionDecisions(_m *AuthIdentity) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentity.Table, authidentity.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, authidentity.AdoptionDecisionsTable, authidentity.AdoptionDecisionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityClient) Hooks() []Hook { + return c.hooks.AuthIdentity +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityClient) Interceptors() []Interceptor { + return c.inters.AuthIdentity +} + +func (c *AuthIdentityClient) mutate(ctx context.Context, m *AuthIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentity mutation op: %q", m.Op()) + } +} + +// AuthIdentityChannelClient is a client for the AuthIdentityChannel schema. +type AuthIdentityChannelClient struct { + config +} + +// NewAuthIdentityChannelClient returns a client for the AuthIdentityChannel from the given config. +func NewAuthIdentityChannelClient(c config) *AuthIdentityChannelClient { + return &AuthIdentityChannelClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authidentitychannel.Hooks(f(g(h())))`. +func (c *AuthIdentityChannelClient) Use(hooks ...Hook) { + c.hooks.AuthIdentityChannel = append(c.hooks.AuthIdentityChannel, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authidentitychannel.Intercept(f(g(h())))`. +func (c *AuthIdentityChannelClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthIdentityChannel = append(c.inters.AuthIdentityChannel, interceptors...) +} + +// Create returns a builder for creating a AuthIdentityChannel entity. +func (c *AuthIdentityChannelClient) Create() *AuthIdentityChannelCreate { + mutation := newAuthIdentityChannelMutation(c.config, OpCreate) + return &AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthIdentityChannel entities. +func (c *AuthIdentityChannelClient) CreateBulk(builders ...*AuthIdentityChannelCreate) *AuthIdentityChannelCreateBulk { + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthIdentityChannelClient) MapCreateBulk(slice any, setFunc func(*AuthIdentityChannelCreate, int)) *AuthIdentityChannelCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthIdentityChannelCreateBulk{err: fmt.Errorf("calling to AuthIdentityChannelClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthIdentityChannelCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthIdentityChannelCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Update() *AuthIdentityChannelUpdate { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdate) + return &AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthIdentityChannelClient) UpdateOne(_m *AuthIdentityChannel) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannel(_m)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthIdentityChannelClient) UpdateOneID(id int64) *AuthIdentityChannelUpdateOne { + mutation := newAuthIdentityChannelMutation(c.config, OpUpdateOne, withAuthIdentityChannelID(id)) + return &AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Delete() *AuthIdentityChannelDelete { + mutation := newAuthIdentityChannelMutation(c.config, OpDelete) + return &AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthIdentityChannelClient) DeleteOne(_m *AuthIdentityChannel) *AuthIdentityChannelDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthIdentityChannelClient) DeleteOneID(id int64) *AuthIdentityChannelDeleteOne { + builder := c.Delete().Where(authidentitychannel.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthIdentityChannelDeleteOne{builder} +} + +// Query returns a query builder for AuthIdentityChannel. +func (c *AuthIdentityChannelClient) Query() *AuthIdentityChannelQuery { + return &AuthIdentityChannelQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthIdentityChannel}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthIdentityChannel entity by its id. +func (c *AuthIdentityChannelClient) Get(ctx context.Context, id int64) (*AuthIdentityChannel, error) { + return c.Query().Where(authidentitychannel.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthIdentityChannelClient) GetX(ctx context.Context, id int64) *AuthIdentityChannel { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryIdentity queries the identity edge of a AuthIdentityChannel. +func (c *AuthIdentityChannelClient) QueryIdentity(_m *AuthIdentityChannel) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(authidentitychannel.Table, authidentitychannel.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, authidentitychannel.IdentityTable, authidentitychannel.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *AuthIdentityChannelClient) Hooks() []Hook { + return c.hooks.AuthIdentityChannel +} + +// Interceptors returns the client interceptors. +func (c *AuthIdentityChannelClient) Interceptors() []Interceptor { + return c.inters.AuthIdentityChannel +} + +func (c *AuthIdentityChannelClient) mutate(ctx context.Context, m *AuthIdentityChannelMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthIdentityChannelCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthIdentityChannelUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthIdentityChannelUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthIdentityChannelDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AuthIdentityChannel mutation op: %q", m.Op()) + } +} + // ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. type ErrorPassthroughRuleClient struct { config @@ -1760,6 +2124,171 @@ func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyReco } } +// IdentityAdoptionDecisionClient is a client for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecisionClient struct { + config +} + +// NewIdentityAdoptionDecisionClient returns a client for the IdentityAdoptionDecision from the given config. +func NewIdentityAdoptionDecisionClient(c config) *IdentityAdoptionDecisionClient { + return &IdentityAdoptionDecisionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `identityadoptiondecision.Hooks(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Use(hooks ...Hook) { + c.hooks.IdentityAdoptionDecision = append(c.hooks.IdentityAdoptionDecision, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `identityadoptiondecision.Intercept(f(g(h())))`. +func (c *IdentityAdoptionDecisionClient) Intercept(interceptors ...Interceptor) { + c.inters.IdentityAdoptionDecision = append(c.inters.IdentityAdoptionDecision, interceptors...) +} + +// Create returns a builder for creating a IdentityAdoptionDecision entity. +func (c *IdentityAdoptionDecisionClient) Create() *IdentityAdoptionDecisionCreate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpCreate) + return &IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdentityAdoptionDecision entities. +func (c *IdentityAdoptionDecisionClient) CreateBulk(builders ...*IdentityAdoptionDecisionCreate) *IdentityAdoptionDecisionCreateBulk { + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdentityAdoptionDecisionClient) MapCreateBulk(slice any, setFunc func(*IdentityAdoptionDecisionCreate, int)) *IdentityAdoptionDecisionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdentityAdoptionDecisionCreateBulk{err: fmt.Errorf("calling to IdentityAdoptionDecisionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdentityAdoptionDecisionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdentityAdoptionDecisionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Update() *IdentityAdoptionDecisionUpdate { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdate) + return &IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdentityAdoptionDecisionClient) UpdateOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecision(_m)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdentityAdoptionDecisionClient) UpdateOneID(id int64) *IdentityAdoptionDecisionUpdateOne { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpUpdateOne, withIdentityAdoptionDecisionID(id)) + return &IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Delete() *IdentityAdoptionDecisionDelete { + mutation := newIdentityAdoptionDecisionMutation(c.config, OpDelete) + return &IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdentityAdoptionDecisionClient) DeleteOne(_m *IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdentityAdoptionDecisionClient) DeleteOneID(id int64) *IdentityAdoptionDecisionDeleteOne { + builder := c.Delete().Where(identityadoptiondecision.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdentityAdoptionDecisionDeleteOne{builder} +} + +// Query returns a query builder for IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) Query() *IdentityAdoptionDecisionQuery { + return &IdentityAdoptionDecisionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdentityAdoptionDecision}, + inters: c.Interceptors(), + } +} + +// Get returns a IdentityAdoptionDecision entity by its id. +func (c *IdentityAdoptionDecisionClient) Get(ctx context.Context, id int64) (*IdentityAdoptionDecision, error) { + return c.Query().Where(identityadoptiondecision.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdentityAdoptionDecisionClient) GetX(ctx context.Context, id int64) *IdentityAdoptionDecision { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryPendingAuthSession queries the pending_auth_session edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryPendingAuthSession(_m *IdentityAdoptionDecision) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryIdentity queries the identity edge of a IdentityAdoptionDecision. +func (c *IdentityAdoptionDecisionClient) QueryIdentity(_m *IdentityAdoptionDecision) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *IdentityAdoptionDecisionClient) Hooks() []Hook { + return c.hooks.IdentityAdoptionDecision +} + +// Interceptors returns the client interceptors. +func (c *IdentityAdoptionDecisionClient) Interceptors() []Interceptor { + return c.inters.IdentityAdoptionDecision +} + +func (c *IdentityAdoptionDecisionClient) mutate(ctx context.Context, m *IdentityAdoptionDecisionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdentityAdoptionDecisionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdentityAdoptionDecisionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdentityAdoptionDecisionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdentityAdoptionDecisionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdentityAdoptionDecision mutation op: %q", m.Op()) + } +} + // PaymentAuditLogClient is a client for the PaymentAuditLog schema. type PaymentAuditLogClient struct { config @@ -2175,6 +2704,171 @@ func (c *PaymentProviderInstanceClient) mutate(ctx context.Context, m *PaymentPr } } +// PendingAuthSessionClient is a client for the PendingAuthSession schema. +type PendingAuthSessionClient struct { + config +} + +// NewPendingAuthSessionClient returns a client for the PendingAuthSession from the given config. +func NewPendingAuthSessionClient(c config) *PendingAuthSessionClient { + return &PendingAuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `pendingauthsession.Hooks(f(g(h())))`. +func (c *PendingAuthSessionClient) Use(hooks ...Hook) { + c.hooks.PendingAuthSession = append(c.hooks.PendingAuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `pendingauthsession.Intercept(f(g(h())))`. +func (c *PendingAuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.PendingAuthSession = append(c.inters.PendingAuthSession, interceptors...) +} + +// Create returns a builder for creating a PendingAuthSession entity. +func (c *PendingAuthSessionClient) Create() *PendingAuthSessionCreate { + mutation := newPendingAuthSessionMutation(c.config, OpCreate) + return &PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of PendingAuthSession entities. +func (c *PendingAuthSessionClient) CreateBulk(builders ...*PendingAuthSessionCreate) *PendingAuthSessionCreateBulk { + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *PendingAuthSessionClient) MapCreateBulk(slice any, setFunc func(*PendingAuthSessionCreate, int)) *PendingAuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PendingAuthSessionCreateBulk{err: fmt.Errorf("calling to PendingAuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PendingAuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PendingAuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Update() *PendingAuthSessionUpdate { + mutation := newPendingAuthSessionMutation(c.config, OpUpdate) + return &PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *PendingAuthSessionClient) UpdateOne(_m *PendingAuthSession) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSession(_m)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *PendingAuthSessionClient) UpdateOneID(id int64) *PendingAuthSessionUpdateOne { + mutation := newPendingAuthSessionMutation(c.config, OpUpdateOne, withPendingAuthSessionID(id)) + return &PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Delete() *PendingAuthSessionDelete { + mutation := newPendingAuthSessionMutation(c.config, OpDelete) + return &PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *PendingAuthSessionClient) DeleteOne(_m *PendingAuthSession) *PendingAuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *PendingAuthSessionClient) DeleteOneID(id int64) *PendingAuthSessionDeleteOne { + builder := c.Delete().Where(pendingauthsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &PendingAuthSessionDeleteOne{builder} +} + +// Query returns a query builder for PendingAuthSession. +func (c *PendingAuthSessionClient) Query() *PendingAuthSessionQuery { + return &PendingAuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypePendingAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a PendingAuthSession entity by its id. +func (c *PendingAuthSessionClient) Get(ctx context.Context, id int64) (*PendingAuthSession, error) { + return c.Query().Where(pendingauthsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *PendingAuthSessionClient) GetX(ctx context.Context, id int64) *PendingAuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryTargetUser queries the target_user edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryTargetUser(_m *PendingAuthSession) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryAdoptionDecision queries the adoption_decision edge of a PendingAuthSession. +func (c *PendingAuthSessionClient) QueryAdoptionDecision(_m *PendingAuthSession) *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, id), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *PendingAuthSessionClient) Hooks() []Hook { + return c.hooks.PendingAuthSession +} + +// Interceptors returns the client interceptors. +func (c *PendingAuthSessionClient) Interceptors() []Interceptor { + return c.inters.PendingAuthSession +} + +func (c *PendingAuthSessionClient) mutate(ctx context.Context, m *PendingAuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PendingAuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PendingAuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PendingAuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PendingAuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PendingAuthSession mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3951,6 +4645,38 @@ func (c *UserClient) QueryPaymentOrders(_m *User) *PaymentOrderQuery { return query } +// QueryAuthIdentities queries the auth_identities edge of a User. +func (c *UserClient) QueryAuthIdentities(_m *User) *AuthIdentityQuery { + query := (&AuthIdentityClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryPendingAuthSessions queries the pending_auth_sessions edge of a User. +func (c *UserClient) QueryPendingAuthSessions(_m *User) *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := _m.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromV = sqlgraph.Neighbors(_m.driver.Dialect(), step) + return fromV, nil + } + return query +} + // QueryUserAllowedGroups queries the user_allowed_groups edge of a User. func (c *UserClient) QueryUserAllowedGroups(_m *User) *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: c.config}).Query() @@ -4628,18 +5354,20 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, IdempotencyRecord, PaymentAuditLog, PaymentOrder, - PaymentProviderInstance, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, AuthIdentity, + AuthIdentityChannel, ErrorPassthroughRule, Group, IdempotencyRecord, + IdentityAdoptionDecision, PaymentAuditLog, PaymentOrder, + PaymentProviderInstance, PendingAuthSession, PromoCode, PromoCodeUsage, Proxy, + RedeemCode, SecuritySecret, Setting, SubscriptionPlan, TLSFingerprintProfile, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 96ed5e03..339e5369 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -98,32 +102,36 @@ var ( func checkColumn(t, c string) error { initCheck.Do(func() { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ - apikey.Table: apikey.ValidColumn, - account.Table: account.ValidColumn, - accountgroup.Table: accountgroup.ValidColumn, - announcement.Table: announcement.ValidColumn, - announcementread.Table: announcementread.ValidColumn, - errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, - group.Table: group.ValidColumn, - idempotencyrecord.Table: idempotencyrecord.ValidColumn, - paymentauditlog.Table: paymentauditlog.ValidColumn, - paymentorder.Table: paymentorder.ValidColumn, - paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, - promocode.Table: promocode.ValidColumn, - promocodeusage.Table: promocodeusage.ValidColumn, - proxy.Table: proxy.ValidColumn, - redeemcode.Table: redeemcode.ValidColumn, - securitysecret.Table: securitysecret.ValidColumn, - setting.Table: setting.ValidColumn, - subscriptionplan.Table: subscriptionplan.ValidColumn, - tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, - usagecleanuptask.Table: usagecleanuptask.ValidColumn, - usagelog.Table: usagelog.ValidColumn, - user.Table: user.ValidColumn, - userallowedgroup.Table: userallowedgroup.ValidColumn, - userattributedefinition.Table: userattributedefinition.ValidColumn, - userattributevalue.Table: userattributevalue.ValidColumn, - usersubscription.Table: usersubscription.ValidColumn, + apikey.Table: apikey.ValidColumn, + account.Table: account.ValidColumn, + accountgroup.Table: accountgroup.ValidColumn, + announcement.Table: announcement.ValidColumn, + announcementread.Table: announcementread.ValidColumn, + authidentity.Table: authidentity.ValidColumn, + authidentitychannel.Table: authidentitychannel.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, + group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, + identityadoptiondecision.Table: identityadoptiondecision.ValidColumn, + paymentauditlog.Table: paymentauditlog.ValidColumn, + paymentorder.Table: paymentorder.ValidColumn, + paymentproviderinstance.Table: paymentproviderinstance.ValidColumn, + pendingauthsession.Table: pendingauthsession.ValidColumn, + promocode.Table: promocode.ValidColumn, + promocodeusage.Table: promocodeusage.ValidColumn, + proxy.Table: proxy.ValidColumn, + redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, + setting.Table: setting.ValidColumn, + subscriptionplan.Table: subscriptionplan.ValidColumn, + tlsfingerprintprofile.Table: tlsfingerprintprofile.ValidColumn, + usagecleanuptask.Table: usagecleanuptask.ValidColumn, + usagelog.Table: usagelog.ValidColumn, + user.Table: user.ValidColumn, + userallowedgroup.Table: userallowedgroup.ValidColumn, + userattributedefinition.Table: userattributedefinition.ValidColumn, + userattributevalue.Table: userattributevalue.ValidColumn, + usersubscription.Table: usersubscription.ValidColumn, }) }) return columnCheck(t, c) diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 199dacea..46ac02bc 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,30 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary +// function as AuthIdentity mutator. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityMutation", m) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary +// function as AuthIdentityChannel mutator. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthIdentityChannelFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.AuthIdentityChannelMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AuthIdentityChannelMutation", m) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary // function as ErrorPassthroughRule mutator. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) @@ -105,6 +129,18 @@ func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent. return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary +// function as IdentityAdoptionDecision mutator. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f IdentityAdoptionDecisionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdentityAdoptionDecisionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdentityAdoptionDecisionMutation", m) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary // function as PaymentAuditLog mutator. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogMutation) (ent.Value, error) @@ -141,6 +177,18 @@ func (f PaymentProviderInstanceFunc) Mutate(ctx context.Context, m ent.Mutation) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PaymentProviderInstanceMutation", m) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary +// function as PendingAuthSession mutator. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f PendingAuthSessionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.PendingAuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PendingAuthSessionMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/identityadoptiondecision.go b/backend/ent/identityadoptiondecision.go new file mode 100644 index 00000000..ecaee65c --- /dev/null +++ b/backend/ent/identityadoptiondecision.go @@ -0,0 +1,223 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecision is the model entity for the IdentityAdoptionDecision schema. +type IdentityAdoptionDecision struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // PendingAuthSessionID holds the value of the "pending_auth_session_id" field. + PendingAuthSessionID int64 `json:"pending_auth_session_id,omitempty"` + // IdentityID holds the value of the "identity_id" field. + IdentityID *int64 `json:"identity_id,omitempty"` + // AdoptDisplayName holds the value of the "adopt_display_name" field. + AdoptDisplayName bool `json:"adopt_display_name,omitempty"` + // AdoptAvatar holds the value of the "adopt_avatar" field. + AdoptAvatar bool `json:"adopt_avatar,omitempty"` + // DecidedAt holds the value of the "decided_at" field. + DecidedAt time.Time `json:"decided_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the IdentityAdoptionDecisionQuery when eager-loading is set. + Edges IdentityAdoptionDecisionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// IdentityAdoptionDecisionEdges holds the relations/edges for other nodes in the graph. +type IdentityAdoptionDecisionEdges struct { + // PendingAuthSession holds the value of the pending_auth_session edge. + PendingAuthSession *PendingAuthSession `json:"pending_auth_session,omitempty"` + // Identity holds the value of the identity edge. + Identity *AuthIdentity `json:"identity,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// PendingAuthSessionOrErr returns the PendingAuthSession value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) PendingAuthSessionOrErr() (*PendingAuthSession, error) { + if e.PendingAuthSession != nil { + return e.PendingAuthSession, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: pendingauthsession.Label} + } + return nil, &NotLoadedError{edge: "pending_auth_session"} +} + +// IdentityOrErr returns the Identity value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e IdentityAdoptionDecisionEdges) IdentityOrErr() (*AuthIdentity, error) { + if e.Identity != nil { + return e.Identity, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: authidentity.Label} + } + return nil, &NotLoadedError{edge: "identity"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdentityAdoptionDecision) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldAdoptDisplayName, identityadoptiondecision.FieldAdoptAvatar: + values[i] = new(sql.NullBool) + case identityadoptiondecision.FieldID, identityadoptiondecision.FieldPendingAuthSessionID, identityadoptiondecision.FieldIdentityID: + values[i] = new(sql.NullInt64) + case identityadoptiondecision.FieldCreatedAt, identityadoptiondecision.FieldUpdatedAt, identityadoptiondecision.FieldDecidedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdentityAdoptionDecision fields. +func (_m *IdentityAdoptionDecision) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case identityadoptiondecision.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case identityadoptiondecision.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case identityadoptiondecision.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case identityadoptiondecision.FieldPendingAuthSessionID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field pending_auth_session_id", values[i]) + } else if value.Valid { + _m.PendingAuthSessionID = value.Int64 + } + case identityadoptiondecision.FieldIdentityID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field identity_id", values[i]) + } else if value.Valid { + _m.IdentityID = new(int64) + *_m.IdentityID = value.Int64 + } + case identityadoptiondecision.FieldAdoptDisplayName: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_display_name", values[i]) + } else if value.Valid { + _m.AdoptDisplayName = value.Bool + } + case identityadoptiondecision.FieldAdoptAvatar: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field adopt_avatar", values[i]) + } else if value.Valid { + _m.AdoptAvatar = value.Bool + } + case identityadoptiondecision.FieldDecidedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field decided_at", values[i]) + } else if value.Valid { + _m.DecidedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdentityAdoptionDecision. +// This includes values selected through modifiers, order, etc. +func (_m *IdentityAdoptionDecision) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryPendingAuthSession queries the "pending_auth_session" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryPendingAuthSession() *PendingAuthSessionQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryPendingAuthSession(_m) +} + +// QueryIdentity queries the "identity" edge of the IdentityAdoptionDecision entity. +func (_m *IdentityAdoptionDecision) QueryIdentity() *AuthIdentityQuery { + return NewIdentityAdoptionDecisionClient(_m.config).QueryIdentity(_m) +} + +// Update returns a builder for updating this IdentityAdoptionDecision. +// Note that you need to call IdentityAdoptionDecision.Unwrap() before calling this method if this IdentityAdoptionDecision +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *IdentityAdoptionDecision) Update() *IdentityAdoptionDecisionUpdateOne { + return NewIdentityAdoptionDecisionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdentityAdoptionDecision entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdentityAdoptionDecision) Unwrap() *IdentityAdoptionDecision { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdentityAdoptionDecision is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdentityAdoptionDecision) String() string { + var builder strings.Builder + builder.WriteString("IdentityAdoptionDecision(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pending_auth_session_id=") + builder.WriteString(fmt.Sprintf("%v", _m.PendingAuthSessionID)) + builder.WriteString(", ") + if v := _m.IdentityID; v != nil { + builder.WriteString("identity_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("adopt_display_name=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptDisplayName)) + builder.WriteString(", ") + builder.WriteString("adopt_avatar=") + builder.WriteString(fmt.Sprintf("%v", _m.AdoptAvatar)) + builder.WriteString(", ") + builder.WriteString("decided_at=") + builder.WriteString(_m.DecidedAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdentityAdoptionDecisions is a parsable slice of IdentityAdoptionDecision. +type IdentityAdoptionDecisions []*IdentityAdoptionDecision diff --git a/backend/ent/identityadoptiondecision/identityadoptiondecision.go b/backend/ent/identityadoptiondecision/identityadoptiondecision.go new file mode 100644 index 00000000..93adaf73 --- /dev/null +++ b/backend/ent/identityadoptiondecision/identityadoptiondecision.go @@ -0,0 +1,159 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the identityadoptiondecision type in the database. + Label = "identity_adoption_decision" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldPendingAuthSessionID holds the string denoting the pending_auth_session_id field in the database. + FieldPendingAuthSessionID = "pending_auth_session_id" + // FieldIdentityID holds the string denoting the identity_id field in the database. + FieldIdentityID = "identity_id" + // FieldAdoptDisplayName holds the string denoting the adopt_display_name field in the database. + FieldAdoptDisplayName = "adopt_display_name" + // FieldAdoptAvatar holds the string denoting the adopt_avatar field in the database. + FieldAdoptAvatar = "adopt_avatar" + // FieldDecidedAt holds the string denoting the decided_at field in the database. + FieldDecidedAt = "decided_at" + // EdgePendingAuthSession holds the string denoting the pending_auth_session edge name in mutations. + EdgePendingAuthSession = "pending_auth_session" + // EdgeIdentity holds the string denoting the identity edge name in mutations. + EdgeIdentity = "identity" + // Table holds the table name of the identityadoptiondecision in the database. + Table = "identity_adoption_decisions" + // PendingAuthSessionTable is the table that holds the pending_auth_session relation/edge. + PendingAuthSessionTable = "identity_adoption_decisions" + // PendingAuthSessionInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionInverseTable = "pending_auth_sessions" + // PendingAuthSessionColumn is the table column denoting the pending_auth_session relation/edge. + PendingAuthSessionColumn = "pending_auth_session_id" + // IdentityTable is the table that holds the identity relation/edge. + IdentityTable = "identity_adoption_decisions" + // IdentityInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + IdentityInverseTable = "auth_identities" + // IdentityColumn is the table column denoting the identity relation/edge. + IdentityColumn = "identity_id" +) + +// Columns holds all SQL columns for identityadoptiondecision fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldPendingAuthSessionID, + FieldIdentityID, + FieldAdoptDisplayName, + FieldAdoptAvatar, + FieldDecidedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // DefaultAdoptDisplayName holds the default value on creation for the "adopt_display_name" field. + DefaultAdoptDisplayName bool + // DefaultAdoptAvatar holds the default value on creation for the "adopt_avatar" field. + DefaultAdoptAvatar bool + // DefaultDecidedAt holds the default value on creation for the "decided_at" field. + DefaultDecidedAt func() time.Time +) + +// OrderOption defines the ordering options for the IdentityAdoptionDecision queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionID orders the results by the pending_auth_session_id field. +func ByPendingAuthSessionID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPendingAuthSessionID, opts...).ToFunc() +} + +// ByIdentityID orders the results by the identity_id field. +func ByIdentityID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdentityID, opts...).ToFunc() +} + +// ByAdoptDisplayName orders the results by the adopt_display_name field. +func ByAdoptDisplayName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptDisplayName, opts...).ToFunc() +} + +// ByAdoptAvatar orders the results by the adopt_avatar field. +func ByAdoptAvatar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAdoptAvatar, opts...).ToFunc() +} + +// ByDecidedAt orders the results by the decided_at field. +func ByDecidedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDecidedAt, opts...).ToFunc() +} + +// ByPendingAuthSessionField orders the results by pending_auth_session field. +func ByPendingAuthSessionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionStep(), sql.OrderByField(field, opts...)) + } +} + +// ByIdentityField orders the results by identity field. +func ByIdentityField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newIdentityStep(), sql.OrderByField(field, opts...)) + } +} +func newPendingAuthSessionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) +} +func newIdentityStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(IdentityInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) +} diff --git a/backend/ent/identityadoptiondecision/where.go b/backend/ent/identityadoptiondecision/where.go new file mode 100644 index 00000000..1968f175 --- /dev/null +++ b/backend/ent/identityadoptiondecision/where.go @@ -0,0 +1,342 @@ +// Code generated by ent, DO NOT EDIT. + +package identityadoptiondecision + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// PendingAuthSessionID applies equality check predicate on the "pending_auth_session_id" field. It's identical to PendingAuthSessionIDEQ. +func PendingAuthSessionID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// IdentityID applies equality check predicate on the "identity_id" field. It's identical to IdentityIDEQ. +func IdentityID(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// AdoptDisplayName applies equality check predicate on the "adopt_display_name" field. It's identical to AdoptDisplayNameEQ. +func AdoptDisplayName(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatar applies equality check predicate on the "adopt_avatar" field. It's identical to AdoptAvatarEQ. +func AdoptAvatar(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// DecidedAt applies equality check predicate on the "decided_at" field. It's identical to DecidedAtEQ. +func DecidedAt(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// PendingAuthSessionIDEQ applies the EQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDNEQ applies the NEQ predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldPendingAuthSessionID, v)) +} + +// PendingAuthSessionIDIn applies the In predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldPendingAuthSessionID, vs...)) +} + +// PendingAuthSessionIDNotIn applies the NotIn predicate on the "pending_auth_session_id" field. +func PendingAuthSessionIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldPendingAuthSessionID, vs...)) +} + +// IdentityIDEQ applies the EQ predicate on the "identity_id" field. +func IdentityIDEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldIdentityID, v)) +} + +// IdentityIDNEQ applies the NEQ predicate on the "identity_id" field. +func IdentityIDNEQ(v int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldIdentityID, v)) +} + +// IdentityIDIn applies the In predicate on the "identity_id" field. +func IdentityIDIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldIdentityID, vs...)) +} + +// IdentityIDNotIn applies the NotIn predicate on the "identity_id" field. +func IdentityIDNotIn(vs ...int64) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldIdentityID, vs...)) +} + +// IdentityIDIsNil applies the IsNil predicate on the "identity_id" field. +func IdentityIDIsNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIsNull(FieldIdentityID)) +} + +// IdentityIDNotNil applies the NotNil predicate on the "identity_id" field. +func IdentityIDNotNil() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotNull(FieldIdentityID)) +} + +// AdoptDisplayNameEQ applies the EQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptDisplayName, v)) +} + +// AdoptDisplayNameNEQ applies the NEQ predicate on the "adopt_display_name" field. +func AdoptDisplayNameNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptDisplayName, v)) +} + +// AdoptAvatarEQ applies the EQ predicate on the "adopt_avatar" field. +func AdoptAvatarEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldAdoptAvatar, v)) +} + +// AdoptAvatarNEQ applies the NEQ predicate on the "adopt_avatar" field. +func AdoptAvatarNEQ(v bool) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldAdoptAvatar, v)) +} + +// DecidedAtEQ applies the EQ predicate on the "decided_at" field. +func DecidedAtEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldEQ(FieldDecidedAt, v)) +} + +// DecidedAtNEQ applies the NEQ predicate on the "decided_at" field. +func DecidedAtNEQ(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNEQ(FieldDecidedAt, v)) +} + +// DecidedAtIn applies the In predicate on the "decided_at" field. +func DecidedAtIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldIn(FieldDecidedAt, vs...)) +} + +// DecidedAtNotIn applies the NotIn predicate on the "decided_at" field. +func DecidedAtNotIn(vs ...time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldNotIn(FieldDecidedAt, vs...)) +} + +// DecidedAtGT applies the GT predicate on the "decided_at" field. +func DecidedAtGT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGT(FieldDecidedAt, v)) +} + +// DecidedAtGTE applies the GTE predicate on the "decided_at" field. +func DecidedAtGTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldGTE(FieldDecidedAt, v)) +} + +// DecidedAtLT applies the LT predicate on the "decided_at" field. +func DecidedAtLT(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLT(FieldDecidedAt, v)) +} + +// DecidedAtLTE applies the LTE predicate on the "decided_at" field. +func DecidedAtLTE(v time.Time) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.FieldLTE(FieldDecidedAt, v)) +} + +// HasPendingAuthSession applies the HasEdge predicate on the "pending_auth_session" edge. +func HasPendingAuthSession() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, PendingAuthSessionTable, PendingAuthSessionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionWith applies the HasEdge predicate on the "pending_auth_session" edge with a given conditions (other predicates). +func HasPendingAuthSessionWith(preds ...predicate.PendingAuthSession) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newPendingAuthSessionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasIdentity applies the HasEdge predicate on the "identity" edge. +func HasIdentity() predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, IdentityTable, IdentityColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasIdentityWith applies the HasEdge predicate on the "identity" edge with a given conditions (other predicates). +func HasIdentityWith(preds ...predicate.AuthIdentity) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + step := newIdentityStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.IdentityAdoptionDecision) predicate.IdentityAdoptionDecision { + return predicate.IdentityAdoptionDecision(sql.NotPredicates(p)) +} diff --git a/backend/ent/identityadoptiondecision_create.go b/backend/ent/identityadoptiondecision_create.go new file mode 100644 index 00000000..491ba9f9 --- /dev/null +++ b/backend/ent/identityadoptiondecision_create.go @@ -0,0 +1,843 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" +) + +// IdentityAdoptionDecisionCreate is the builder for creating a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionCreate struct { + config + mutation *IdentityAdoptionDecisionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetCreatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableCreatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableUpdatedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetPendingAuthSessionID(v) + return _c +} + +// SetIdentityID sets the "identity_id" field. +func (_c *IdentityAdoptionDecisionCreate) SetIdentityID(v int64) *IdentityAdoptionDecisionCreate { + _c.mutation.SetIdentityID(v) + return _c +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetIdentityID(*v) + } + return _c +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptDisplayName(v) + return _c +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptDisplayName(*v) + } + return _c +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_c *IdentityAdoptionDecisionCreate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionCreate { + _c.mutation.SetAdoptAvatar(v) + return _c +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetAdoptAvatar(*v) + } + return _c +} + +// SetDecidedAt sets the "decided_at" field. +func (_c *IdentityAdoptionDecisionCreate) SetDecidedAt(v time.Time) *IdentityAdoptionDecisionCreate { + _c.mutation.SetDecidedAt(v) + return _c +} + +// SetNillableDecidedAt sets the "decided_at" field if the given value is not nil. +func (_c *IdentityAdoptionDecisionCreate) SetNillableDecidedAt(v *time.Time) *IdentityAdoptionDecisionCreate { + if v != nil { + _c.SetDecidedAt(*v) + } + return _c +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_c *IdentityAdoptionDecisionCreate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionCreate { + return _c.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_c *IdentityAdoptionDecisionCreate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionCreate { + return _c.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_c *IdentityAdoptionDecisionCreate) Mutation() *IdentityAdoptionDecisionMutation { + return _c.mutation +} + +// Save creates the IdentityAdoptionDecision in the database. +func (_c *IdentityAdoptionDecisionCreate) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdentityAdoptionDecisionCreate) SaveX(ctx context.Context) *IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdentityAdoptionDecisionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := identityadoptiondecision.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + v := identityadoptiondecision.DefaultAdoptDisplayName + _c.mutation.SetAdoptDisplayName(v) + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + v := identityadoptiondecision.DefaultAdoptAvatar + _c.mutation.SetAdoptAvatar(v) + } + if _, ok := _c.mutation.DecidedAt(); !ok { + v := identityadoptiondecision.DefaultDecidedAt() + _c.mutation.SetDecidedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *IdentityAdoptionDecisionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.updated_at"`)} + } + if _, ok := _c.mutation.PendingAuthSessionID(); !ok { + return &ValidationError{Name: "pending_auth_session_id", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.pending_auth_session_id"`)} + } + if _, ok := _c.mutation.AdoptDisplayName(); !ok { + return &ValidationError{Name: "adopt_display_name", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_display_name"`)} + } + if _, ok := _c.mutation.AdoptAvatar(); !ok { + return &ValidationError{Name: "adopt_avatar", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.adopt_avatar"`)} + } + if _, ok := _c.mutation.DecidedAt(); !ok { + return &ValidationError{Name: "decided_at", err: errors.New(`ent: missing required field "IdentityAdoptionDecision.decided_at"`)} + } + if len(_c.mutation.PendingAuthSessionIDs()) == 0 { + return &ValidationError{Name: "pending_auth_session", err: errors.New(`ent: missing required edge "IdentityAdoptionDecision.pending_auth_session"`)} + } + return nil +} + +func (_c *IdentityAdoptionDecisionCreate) sqlSave(ctx context.Context) (*IdentityAdoptionDecision, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *IdentityAdoptionDecisionCreate) createSpec() (*IdentityAdoptionDecision, *sqlgraph.CreateSpec) { + var ( + _node = &IdentityAdoptionDecision{config: _c.config} + _spec = sqlgraph.NewCreateSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + _node.AdoptDisplayName = value + } + if value, ok := _c.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + _node.AdoptAvatar = value + } + if value, ok := _c.mutation.DecidedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldDecidedAt, field.TypeTime, value) + _node.DecidedAt = value + } + if nodes := _c.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.PendingAuthSessionID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.IdentityID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreate) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertOne{ + create: _c, + } +} + +type ( + // IdentityAdoptionDecisionUpsertOne is the builder for "upsert"-ing + // one IdentityAdoptionDecision node. + IdentityAdoptionDecisionUpsertOne struct { + create *IdentityAdoptionDecisionCreate + } + + // IdentityAdoptionDecisionUpsert is the "OnConflict" setter. + IdentityAdoptionDecisionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsert) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldUpdatedAt) + return u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldPendingAuthSessionID, v) + return u +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldPendingAuthSessionID) + return u +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldIdentityID, v) + return u +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldIdentityID) + return u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsert) ClearIdentityID() *IdentityAdoptionDecisionUpsert { + u.SetNull(identityadoptiondecision.FieldIdentityID) + return u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptDisplayName, v) + return u +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptDisplayName) + return u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsert) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsert { + u.Set(identityadoptiondecision.FieldAdoptAvatar, v) + return u +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsert) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsert { + u.SetExcluded(identityadoptiondecision.FieldAdoptAvatar) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) UpdateNewValues() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := u.create.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertOne) Ignore() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertOne) DoNothing() *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreate.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertOne) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertOne) ClearIdentityID() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertOne) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertOne { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *IdentityAdoptionDecisionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// IdentityAdoptionDecisionCreateBulk is the builder for creating many IdentityAdoptionDecision entities in bulk. +type IdentityAdoptionDecisionCreateBulk struct { + config + err error + builders []*IdentityAdoptionDecisionCreate + conflict []sql.ConflictOption +} + +// Save creates the IdentityAdoptionDecision entities in the database. +func (_c *IdentityAdoptionDecisionCreateBulk) Save(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*IdentityAdoptionDecision, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*IdentityAdoptionDecisionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) SaveX(ctx context.Context) []*IdentityAdoptionDecision { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdentityAdoptionDecisionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdentityAdoptionDecisionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.IdentityAdoptionDecision.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.IdentityAdoptionDecisionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = opts + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdentityAdoptionDecisionCreateBulk) OnConflictColumns(columns ...string) *IdentityAdoptionDecisionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdentityAdoptionDecisionUpsertBulk{ + create: _c, + } +} + +// IdentityAdoptionDecisionUpsertBulk is the builder for "upsert"-ing +// a bulk of IdentityAdoptionDecision nodes. +type IdentityAdoptionDecisionUpsertBulk struct { + create *IdentityAdoptionDecisionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateNewValues() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldCreatedAt) + } + if _, exists := b.mutation.DecidedAt(); exists { + s.SetIgnore(identityadoptiondecision.FieldDecidedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdentityAdoptionDecision.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdentityAdoptionDecisionUpsertBulk) Ignore() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdentityAdoptionDecisionUpsertBulk) DoNothing() *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdentityAdoptionDecisionCreateBulk.OnConflict +// documentation for more info. +func (u *IdentityAdoptionDecisionUpsertBulk) Update(set func(*IdentityAdoptionDecisionUpsert)) *IdentityAdoptionDecisionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdentityAdoptionDecisionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateUpdatedAt() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetPendingAuthSessionID(v) + }) +} + +// UpdatePendingAuthSessionID sets the "pending_auth_session_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdatePendingAuthSessionID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdatePendingAuthSessionID() + }) +} + +// SetIdentityID sets the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetIdentityID(v int64) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetIdentityID(v) + }) +} + +// UpdateIdentityID sets the "identity_id" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateIdentityID() + }) +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (u *IdentityAdoptionDecisionUpsertBulk) ClearIdentityID() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.ClearIdentityID() + }) +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptDisplayName(v) + }) +} + +// UpdateAdoptDisplayName sets the "adopt_display_name" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptDisplayName() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptDisplayName() + }) +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (u *IdentityAdoptionDecisionUpsertBulk) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.SetAdoptAvatar(v) + }) +} + +// UpdateAdoptAvatar sets the "adopt_avatar" field to the value that was provided on create. +func (u *IdentityAdoptionDecisionUpsertBulk) UpdateAdoptAvatar() *IdentityAdoptionDecisionUpsertBulk { + return u.Update(func(s *IdentityAdoptionDecisionUpsert) { + s.UpdateAdoptAvatar() + }) +} + +// Exec executes the query. +func (u *IdentityAdoptionDecisionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdentityAdoptionDecisionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdentityAdoptionDecisionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *IdentityAdoptionDecisionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_delete.go b/backend/ent/identityadoptiondecision_delete.go new file mode 100644 index 00000000..ef3d328d --- /dev/null +++ b/backend/ent/identityadoptiondecision_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionDelete is the builder for deleting a IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDelete struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDelete) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdentityAdoptionDecisionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdentityAdoptionDecisionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(identityadoptiondecision.Table, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdentityAdoptionDecisionDeleteOne is the builder for deleting a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionDeleteOne struct { + _d *IdentityAdoptionDecisionDelete +} + +// Where appends a list predicates to the IdentityAdoptionDecisionDelete builder. +func (_d *IdentityAdoptionDecisionDeleteOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdentityAdoptionDecisionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{identityadoptiondecision.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdentityAdoptionDecisionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/identityadoptiondecision_query.go b/backend/ent/identityadoptiondecision_query.go new file mode 100644 index 00000000..4082d8ee --- /dev/null +++ b/backend/ent/identityadoptiondecision_query.go @@ -0,0 +1,721 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionQuery is the builder for querying IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionQuery struct { + config + ctx *QueryContext + order []identityadoptiondecision.OrderOption + inters []Interceptor + predicates []predicate.IdentityAdoptionDecision + withPendingAuthSession *PendingAuthSessionQuery + withIdentity *AuthIdentityQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdentityAdoptionDecisionQuery builder. +func (_q *IdentityAdoptionDecisionQuery) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdentityAdoptionDecisionQuery) Limit(limit int) *IdentityAdoptionDecisionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdentityAdoptionDecisionQuery) Offset(offset int) *IdentityAdoptionDecisionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdentityAdoptionDecisionQuery) Unique(unique bool) *IdentityAdoptionDecisionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdentityAdoptionDecisionQuery) Order(o ...identityadoptiondecision.OrderOption) *IdentityAdoptionDecisionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryPendingAuthSession chains the current query on the "pending_auth_session" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryPendingAuthSession() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, identityadoptiondecision.PendingAuthSessionTable, identityadoptiondecision.PendingAuthSessionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryIdentity chains the current query on the "identity" edge. +func (_q *IdentityAdoptionDecisionQuery) QueryIdentity() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(identityadoptiondecision.Table, identityadoptiondecision.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, identityadoptiondecision.IdentityTable, identityadoptiondecision.IdentityColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first IdentityAdoptionDecision entity from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision was found. +func (_q *IdentityAdoptionDecisionQuery) First(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{identityadoptiondecision.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdentityAdoptionDecision ID from the query. +// Returns a *NotFoundError when no IdentityAdoptionDecision ID was found. +func (_q *IdentityAdoptionDecisionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{identityadoptiondecision.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdentityAdoptionDecision entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision entity is found. +// Returns a *NotFoundError when no IdentityAdoptionDecision entities are found. +func (_q *IdentityAdoptionDecisionQuery) Only(ctx context.Context) (*IdentityAdoptionDecision, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{identityadoptiondecision.Label} + default: + return nil, &NotSingularError{identityadoptiondecision.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdentityAdoptionDecision ID in the query. +// Returns a *NotSingularError when more than one IdentityAdoptionDecision ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdentityAdoptionDecisionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{identityadoptiondecision.Label} + default: + err = &NotSingularError{identityadoptiondecision.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdentityAdoptionDecisions. +func (_q *IdentityAdoptionDecisionQuery) All(ctx context.Context) ([]*IdentityAdoptionDecision, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdentityAdoptionDecision, *IdentityAdoptionDecisionQuery]() + return withInterceptors[[]*IdentityAdoptionDecision](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) AllX(ctx context.Context) []*IdentityAdoptionDecision { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdentityAdoptionDecision IDs. +func (_q *IdentityAdoptionDecisionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(identityadoptiondecision.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdentityAdoptionDecisionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdentityAdoptionDecisionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdentityAdoptionDecisionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdentityAdoptionDecisionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdentityAdoptionDecisionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *IdentityAdoptionDecisionQuery) Clone() *IdentityAdoptionDecisionQuery { + if _q == nil { + return nil + } + return &IdentityAdoptionDecisionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]identityadoptiondecision.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdentityAdoptionDecision{}, _q.predicates...), + withPendingAuthSession: _q.withPendingAuthSession.Clone(), + withIdentity: _q.withIdentity.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithPendingAuthSession tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_session" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithPendingAuthSession(opts ...func(*PendingAuthSessionQuery)) *IdentityAdoptionDecisionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSession = query + return _q +} + +// WithIdentity tells the query-builder to eager-load the nodes that are connected to +// the "identity" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *IdentityAdoptionDecisionQuery) WithIdentity(opts ...func(*AuthIdentityQuery)) *IdentityAdoptionDecisionQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withIdentity = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// GroupBy(identityadoptiondecision.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) GroupBy(field string, fields ...string) *IdentityAdoptionDecisionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdentityAdoptionDecisionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = identityadoptiondecision.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdentityAdoptionDecision.Query(). +// Select(identityadoptiondecision.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdentityAdoptionDecisionQuery) Select(fields ...string) *IdentityAdoptionDecisionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdentityAdoptionDecisionSelect{IdentityAdoptionDecisionQuery: _q} + sbuild.label = identityadoptiondecision.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdentityAdoptionDecisionSelect configured with the given aggregations. +func (_q *IdentityAdoptionDecisionQuery) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdentityAdoptionDecisionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !identityadoptiondecision.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdentityAdoptionDecision, error) { + var ( + nodes = []*IdentityAdoptionDecision{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withPendingAuthSession != nil, + _q.withIdentity != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdentityAdoptionDecision).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdentityAdoptionDecision{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withPendingAuthSession; query != nil { + if err := _q.loadPendingAuthSession(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *PendingAuthSession) { n.Edges.PendingAuthSession = e }); err != nil { + return nil, err + } + } + if query := _q.withIdentity; query != nil { + if err := _q.loadIdentity(ctx, query, nodes, nil, + func(n *IdentityAdoptionDecision, e *AuthIdentity) { n.Edges.Identity = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *IdentityAdoptionDecisionQuery) loadPendingAuthSession(ctx context.Context, query *PendingAuthSessionQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *PendingAuthSession)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + fk := nodes[i].PendingAuthSessionID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(pendingauthsession.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "pending_auth_session_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *IdentityAdoptionDecisionQuery) loadIdentity(ctx context.Context, query *AuthIdentityQuery, nodes []*IdentityAdoptionDecision, init func(*IdentityAdoptionDecision), assign func(*IdentityAdoptionDecision, *AuthIdentity)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*IdentityAdoptionDecision) + for i := range nodes { + if nodes[i].IdentityID == nil { + continue + } + fk := *nodes[i].IdentityID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(authidentity.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "identity_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (_q *IdentityAdoptionDecisionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdentityAdoptionDecisionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for i := range fields { + if fields[i] != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withPendingAuthSession != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + if _q.withIdentity != nil { + _spec.Node.AddColumnOnce(identityadoptiondecision.FieldIdentityID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdentityAdoptionDecisionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(identityadoptiondecision.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = identityadoptiondecision.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdentityAdoptionDecisionQuery) ForUpdate(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdentityAdoptionDecisionQuery) ForShare(opts ...sql.LockOption) *IdentityAdoptionDecisionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdentityAdoptionDecisionGroupBy is the group-by builder for IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionGroupBy struct { + selector + build *IdentityAdoptionDecisionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdentityAdoptionDecisionGroupBy) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *IdentityAdoptionDecisionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdentityAdoptionDecisionGroupBy) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdentityAdoptionDecisionSelect is the builder for selecting fields of IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionSelect struct { + *IdentityAdoptionDecisionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdentityAdoptionDecisionSelect) Aggregate(fns ...AggregateFunc) *IdentityAdoptionDecisionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *IdentityAdoptionDecisionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdentityAdoptionDecisionQuery, *IdentityAdoptionDecisionSelect](ctx, _s.IdentityAdoptionDecisionQuery, _s, _s.inters, v) +} + +func (_s *IdentityAdoptionDecisionSelect) sqlScan(ctx context.Context, root *IdentityAdoptionDecisionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/identityadoptiondecision_update.go b/backend/ent/identityadoptiondecision_update.go new file mode 100644 index 00000000..0ca21d27 --- /dev/null +++ b/backend/ent/identityadoptiondecision_update.go @@ -0,0 +1,532 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdentityAdoptionDecisionUpdate is the builder for updating IdentityAdoptionDecision entities. +type IdentityAdoptionDecisionUpdate struct { + config + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdate) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdate) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentityID() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdate) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdate { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdate) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdate { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdate { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdate { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdate) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdate) ClearIdentity() *IdentityAdoptionDecisionUpdate { + _u.mutation.ClearIdentity() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdentityAdoptionDecisionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *IdentityAdoptionDecisionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdate) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdentityAdoptionDecisionUpdateOne is the builder for updating a single IdentityAdoptionDecision entity. +type IdentityAdoptionDecisionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdentityAdoptionDecisionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetUpdatedAt(v time.Time) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSessionID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetPendingAuthSessionID(v) + return _u +} + +// SetNillablePendingAuthSessionID sets the "pending_auth_session_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillablePendingAuthSessionID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetPendingAuthSessionID(*v) + } + return _u +} + +// SetIdentityID sets the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentityID(v int64) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetIdentityID(v) + return _u +} + +// SetNillableIdentityID sets the "identity_id" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableIdentityID(v *int64) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetIdentityID(*v) + } + return _u +} + +// ClearIdentityID clears the value of the "identity_id" field. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentityID() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentityID() + return _u +} + +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptDisplayName(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptDisplayName(v) + return _u +} + +// SetNillableAdoptDisplayName sets the "adopt_display_name" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptDisplayName(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptDisplayName(*v) + } + return _u +} + +// SetAdoptAvatar sets the "adopt_avatar" field. +func (_u *IdentityAdoptionDecisionUpdateOne) SetAdoptAvatar(v bool) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.SetAdoptAvatar(v) + return _u +} + +// SetNillableAdoptAvatar sets the "adopt_avatar" field if the given value is not nil. +func (_u *IdentityAdoptionDecisionUpdateOne) SetNillableAdoptAvatar(v *bool) *IdentityAdoptionDecisionUpdateOne { + if v != nil { + _u.SetAdoptAvatar(*v) + } + return _u +} + +// SetPendingAuthSession sets the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetPendingAuthSession(v *PendingAuthSession) *IdentityAdoptionDecisionUpdateOne { + return _u.SetPendingAuthSessionID(v.ID) +} + +// SetIdentity sets the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) SetIdentity(v *AuthIdentity) *IdentityAdoptionDecisionUpdateOne { + return _u.SetIdentityID(v.ID) +} + +// Mutation returns the IdentityAdoptionDecisionMutation object of the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Mutation() *IdentityAdoptionDecisionMutation { + return _u.mutation +} + +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearPendingAuthSession() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearPendingAuthSession() + return _u +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (_u *IdentityAdoptionDecisionUpdateOne) ClearIdentity() *IdentityAdoptionDecisionUpdateOne { + _u.mutation.ClearIdentity() + return _u +} + +// Where appends a list predicates to the IdentityAdoptionDecisionUpdate builder. +func (_u *IdentityAdoptionDecisionUpdateOne) Where(ps ...predicate.IdentityAdoptionDecision) *IdentityAdoptionDecisionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdentityAdoptionDecisionUpdateOne) Select(field string, fields ...string) *IdentityAdoptionDecisionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdentityAdoptionDecision entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Save(ctx context.Context) (*IdentityAdoptionDecision, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) SaveX(ctx context.Context) *IdentityAdoptionDecision { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdentityAdoptionDecisionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdentityAdoptionDecisionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdentityAdoptionDecisionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := identityadoptiondecision.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdentityAdoptionDecisionUpdateOne) check() error { + if _u.mutation.PendingAuthSessionCleared() && len(_u.mutation.PendingAuthSessionIDs()) > 0 { + return errors.New(`ent: clearing a required unique edge "IdentityAdoptionDecision.pending_auth_session"`) + } + return nil +} + +func (_u *IdentityAdoptionDecisionUpdateOne) sqlSave(ctx context.Context) (_node *IdentityAdoptionDecision, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(identityadoptiondecision.Table, identityadoptiondecision.Columns, sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdentityAdoptionDecision.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, identityadoptiondecision.FieldID) + for _, f := range fields { + if !identityadoptiondecision.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != identityadoptiondecision.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(identityadoptiondecision.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.AdoptDisplayName(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptDisplayName, field.TypeBool, value) + } + if value, ok := _u.mutation.AdoptAvatar(); ok { + _spec.SetField(identityadoptiondecision.FieldAdoptAvatar, field.TypeBool, value) + } + if _u.mutation.PendingAuthSessionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: true, + Table: identityadoptiondecision.PendingAuthSessionTable, + Columns: []string{identityadoptiondecision.PendingAuthSessionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.IdentityCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.IdentityIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: identityadoptiondecision.IdentityTable, + Columns: []string{identityadoptiondecision.IdentityColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &IdentityAdoptionDecision{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{identityadoptiondecision.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8d8320bb..157c5122 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,12 +13,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -228,6 +232,60 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The AuthIdentityFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityFunc func(context.Context, *ent.AuthIdentityQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The TraverseAuthIdentity type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentity func(context.Context, *ent.AuthIdentityQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentity) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentity) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityQuery", q) +} + +// The AuthIdentityChannelFunc type is an adapter to allow the use of ordinary function as a Querier. +type AuthIdentityChannelFunc func(context.Context, *ent.AuthIdentityChannelQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f AuthIdentityChannelFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + +// The TraverseAuthIdentityChannel type is an adapter to allow the use of ordinary function as Traverser. +type TraverseAuthIdentityChannel func(context.Context, *ent.AuthIdentityChannelQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseAuthIdentityChannel) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseAuthIdentityChannel) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.AuthIdentityChannelQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.AuthIdentityChannelQuery", q) +} + // The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) @@ -309,6 +367,33 @@ func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) er return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) } +// The IdentityAdoptionDecisionFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdentityAdoptionDecisionFunc func(context.Context, *ent.IdentityAdoptionDecisionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdentityAdoptionDecisionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + +// The TraverseIdentityAdoptionDecision type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdentityAdoptionDecision func(context.Context, *ent.IdentityAdoptionDecisionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdentityAdoptionDecision) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdentityAdoptionDecision) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdentityAdoptionDecisionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdentityAdoptionDecisionQuery", q) +} + // The PaymentAuditLogFunc type is an adapter to allow the use of ordinary function as a Querier. type PaymentAuditLogFunc func(context.Context, *ent.PaymentAuditLogQuery) (ent.Value, error) @@ -390,6 +475,33 @@ func (f TraversePaymentProviderInstance) Traverse(ctx context.Context, q ent.Que return fmt.Errorf("unexpected query type %T. expect *ent.PaymentProviderInstanceQuery", q) } +// The PendingAuthSessionFunc type is an adapter to allow the use of ordinary function as a Querier. +type PendingAuthSessionFunc func(context.Context, *ent.PendingAuthSessionQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f PendingAuthSessionFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + +// The TraversePendingAuthSession type is an adapter to allow the use of ordinary function as Traverser. +type TraversePendingAuthSession func(context.Context, *ent.PendingAuthSessionQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraversePendingAuthSession) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraversePendingAuthSession) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.PendingAuthSessionQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.PendingAuthSessionQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -808,18 +920,26 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.AuthIdentityQuery: + return &query[*ent.AuthIdentityQuery, predicate.AuthIdentity, authidentity.OrderOption]{typ: ent.TypeAuthIdentity, tq: q}, nil + case *ent.AuthIdentityChannelQuery: + return &query[*ent.AuthIdentityChannelQuery, predicate.AuthIdentityChannel, authidentitychannel.OrderOption]{typ: ent.TypeAuthIdentityChannel, tq: q}, nil case *ent.ErrorPassthroughRuleQuery: return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.IdempotencyRecordQuery: return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil + case *ent.IdentityAdoptionDecisionQuery: + return &query[*ent.IdentityAdoptionDecisionQuery, predicate.IdentityAdoptionDecision, identityadoptiondecision.OrderOption]{typ: ent.TypeIdentityAdoptionDecision, tq: q}, nil case *ent.PaymentAuditLogQuery: return &query[*ent.PaymentAuditLogQuery, predicate.PaymentAuditLog, paymentauditlog.OrderOption]{typ: ent.TypePaymentAuditLog, tq: q}, nil case *ent.PaymentOrderQuery: return &query[*ent.PaymentOrderQuery, predicate.PaymentOrder, paymentorder.OrderOption]{typ: ent.TypePaymentOrder, tq: q}, nil case *ent.PaymentProviderInstanceQuery: return &query[*ent.PaymentProviderInstanceQuery, predicate.PaymentProviderInstance, paymentproviderinstance.OrderOption]{typ: ent.TypePaymentProviderInstance, tq: q}, nil + case *ent.PendingAuthSessionQuery: + return &query[*ent.PendingAuthSessionQuery, predicate.PendingAuthSession, pendingauthsession.OrderOption]{typ: ent.TypePendingAuthSession, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 68bdbf55..bf41e73b 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -338,6 +338,89 @@ var ( }, }, } + // AuthIdentitiesColumns holds the columns for the "auth_identities" table. + AuthIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "issuer", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "user_id", Type: field.TypeInt64}, + } + // AuthIdentitiesTable holds the schema information for the "auth_identities" table. + AuthIdentitiesTable = &schema.Table{ + Name: "auth_identities", + Columns: AuthIdentitiesColumns, + PrimaryKey: []*schema.Column{AuthIdentitiesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identities_users_auth_identities", + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentity_provider_type_provider_key_provider_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentitiesColumns[3], AuthIdentitiesColumns[4], AuthIdentitiesColumns[5]}, + }, + { + Name: "authidentity_user_id", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9]}, + }, + { + Name: "authidentity_user_id_provider_type", + Unique: false, + Columns: []*schema.Column{AuthIdentitiesColumns[9], AuthIdentitiesColumns[3]}, + }, + }, + } + // AuthIdentityChannelsColumns holds the columns for the "auth_identity_channels" table. + AuthIdentityChannelsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel", Type: field.TypeString, Size: 20}, + {Name: "channel_app_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "channel_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "identity_id", Type: field.TypeInt64}, + } + // AuthIdentityChannelsTable holds the schema information for the "auth_identity_channels" table. + AuthIdentityChannelsTable = &schema.Table{ + Name: "auth_identity_channels", + Columns: AuthIdentityChannelsColumns, + PrimaryKey: []*schema.Column{AuthIdentityChannelsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "auth_identity_channels_auth_identities_channels", + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "authidentitychannel_provider_type_provider_key_channel_channel_app_id_channel_subject", + Unique: true, + Columns: []*schema.Column{AuthIdentityChannelsColumns[3], AuthIdentityChannelsColumns[4], AuthIdentityChannelsColumns[5], AuthIdentityChannelsColumns[6], AuthIdentityChannelsColumns[7]}, + }, + { + Name: "authidentitychannel_identity_id", + Unique: false, + Columns: []*schema.Column{AuthIdentityChannelsColumns[9]}, + }, + }, + } // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. ErrorPassthroughRulesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -485,6 +568,49 @@ var ( }, }, } + // IdentityAdoptionDecisionsColumns holds the columns for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "adopt_display_name", Type: field.TypeBool, Default: false}, + {Name: "adopt_avatar", Type: field.TypeBool, Default: false}, + {Name: "decided_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "identity_id", Type: field.TypeInt64, Nullable: true}, + {Name: "pending_auth_session_id", Type: field.TypeInt64, Unique: true}, + } + // IdentityAdoptionDecisionsTable holds the schema information for the "identity_adoption_decisions" table. + IdentityAdoptionDecisionsTable = &schema.Table{ + Name: "identity_adoption_decisions", + Columns: IdentityAdoptionDecisionsColumns, + PrimaryKey: []*schema.Column{IdentityAdoptionDecisionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "identity_adoption_decisions_auth_identities_adoption_decisions", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + RefColumns: []*schema.Column{AuthIdentitiesColumns[0]}, + OnDelete: schema.SetNull, + }, + { + Symbol: "identity_adoption_decisions_pending_auth_sessions_adoption_decision", + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + RefColumns: []*schema.Column{PendingAuthSessionsColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "identityadoptiondecision_pending_auth_session_id", + Unique: true, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[7]}, + }, + { + Name: "identityadoptiondecision_identity_id", + Unique: false, + Columns: []*schema.Column{IdentityAdoptionDecisionsColumns[6]}, + }, + }, + } // PaymentAuditLogsColumns holds the columns for the "payment_audit_logs" table. PaymentAuditLogsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -638,6 +764,72 @@ var ( }, }, } + // PendingAuthSessionsColumns holds the columns for the "pending_auth_sessions" table. + PendingAuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "session_token", Type: field.TypeString, Size: 255}, + {Name: "intent", Type: field.TypeString, Size: 40}, + {Name: "provider_type", Type: field.TypeString, Size: 20}, + {Name: "provider_key", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "provider_subject", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "redirect_to", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "resolved_email", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "registration_password_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "upstream_identity_claims", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "local_flow_state", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "browser_session_key", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_hash", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "completion_code_expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "email_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "password_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "totp_verified_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "consumed_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "target_user_id", Type: field.TypeInt64, Nullable: true}, + } + // PendingAuthSessionsTable holds the schema information for the "pending_auth_sessions" table. + PendingAuthSessionsTable = &schema.Table{ + Name: "pending_auth_sessions", + Columns: PendingAuthSessionsColumns, + PrimaryKey: []*schema.Column{PendingAuthSessionsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "pending_auth_sessions_users_pending_auth_sessions", + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.SetNull, + }, + }, + Indexes: []*schema.Index{ + { + Name: "pendingauthsession_session_token", + Unique: true, + Columns: []*schema.Column{PendingAuthSessionsColumns[3]}, + }, + { + Name: "pendingauthsession_target_user_id", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[21]}, + }, + { + Name: "pendingauthsession_expires_at", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[19]}, + }, + { + Name: "pendingauthsession_provider_type_provider_key_provider_subject", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[5], PendingAuthSessionsColumns[6], PendingAuthSessionsColumns[7]}, + }, + { + Name: "pendingauthsession_completion_code_hash", + Unique: false, + Columns: []*schema.Column{PendingAuthSessionsColumns[14]}, + }, + }, + } // PromoCodesColumns holds the columns for the "promo_codes" table. PromoCodesColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -1079,6 +1271,9 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"}, + {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -1318,12 +1513,16 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + AuthIdentitiesTable, + AuthIdentityChannelsTable, ErrorPassthroughRulesTable, GroupsTable, IdempotencyRecordsTable, + IdentityAdoptionDecisionsTable, PaymentAuditLogsTable, PaymentOrdersTable, PaymentProviderInstancesTable, + PendingAuthSessionsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1365,6 +1564,14 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + AuthIdentitiesTable.ForeignKeys[0].RefTable = UsersTable + AuthIdentitiesTable.Annotation = &entsql.Annotation{ + Table: "auth_identities", + } + AuthIdentityChannelsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + AuthIdentityChannelsTable.Annotation = &entsql.Annotation{ + Table: "auth_identity_channels", + } ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ Table: "error_passthrough_rules", } @@ -1374,6 +1581,11 @@ func init() { IdempotencyRecordsTable.Annotation = &entsql.Annotation{ Table: "idempotency_records", } + IdentityAdoptionDecisionsTable.ForeignKeys[0].RefTable = AuthIdentitiesTable + IdentityAdoptionDecisionsTable.ForeignKeys[1].RefTable = PendingAuthSessionsTable + IdentityAdoptionDecisionsTable.Annotation = &entsql.Annotation{ + Table: "identity_adoption_decisions", + } PaymentAuditLogsTable.Annotation = &entsql.Annotation{ Table: "payment_audit_logs", } @@ -1384,6 +1596,10 @@ func init() { PaymentProviderInstancesTable.Annotation = &entsql.Annotation{ Table: "payment_provider_instances", } + PendingAuthSessionsTable.ForeignKeys[0].RefTable = UsersTable + PendingAuthSessionsTable.Annotation = &entsql.Annotation{ + Table: "pending_auth_sessions", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 524ccb92..12905c9a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,12 +17,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -51,32 +55,36 @@ const ( OpUpdateOne = ent.OpUpdateOne // Node types. - TypeAPIKey = "APIKey" - TypeAccount = "Account" - TypeAccountGroup = "AccountGroup" - TypeAnnouncement = "Announcement" - TypeAnnouncementRead = "AnnouncementRead" - TypeErrorPassthroughRule = "ErrorPassthroughRule" - TypeGroup = "Group" - TypeIdempotencyRecord = "IdempotencyRecord" - TypePaymentAuditLog = "PaymentAuditLog" - TypePaymentOrder = "PaymentOrder" - TypePaymentProviderInstance = "PaymentProviderInstance" - TypePromoCode = "PromoCode" - TypePromoCodeUsage = "PromoCodeUsage" - TypeProxy = "Proxy" - TypeRedeemCode = "RedeemCode" - TypeSecuritySecret = "SecuritySecret" - TypeSetting = "Setting" - TypeSubscriptionPlan = "SubscriptionPlan" - TypeTLSFingerprintProfile = "TLSFingerprintProfile" - TypeUsageCleanupTask = "UsageCleanupTask" - TypeUsageLog = "UsageLog" - TypeUser = "User" - TypeUserAllowedGroup = "UserAllowedGroup" - TypeUserAttributeDefinition = "UserAttributeDefinition" - TypeUserAttributeValue = "UserAttributeValue" - TypeUserSubscription = "UserSubscription" + TypeAPIKey = "APIKey" + TypeAccount = "Account" + TypeAccountGroup = "AccountGroup" + TypeAnnouncement = "Announcement" + TypeAnnouncementRead = "AnnouncementRead" + TypeAuthIdentity = "AuthIdentity" + TypeAuthIdentityChannel = "AuthIdentityChannel" + TypeErrorPassthroughRule = "ErrorPassthroughRule" + TypeGroup = "Group" + TypeIdempotencyRecord = "IdempotencyRecord" + TypeIdentityAdoptionDecision = "IdentityAdoptionDecision" + TypePaymentAuditLog = "PaymentAuditLog" + TypePaymentOrder = "PaymentOrder" + TypePaymentProviderInstance = "PaymentProviderInstance" + TypePendingAuthSession = "PendingAuthSession" + TypePromoCode = "PromoCode" + TypePromoCodeUsage = "PromoCodeUsage" + TypeProxy = "Proxy" + TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" + TypeSetting = "Setting" + TypeSubscriptionPlan = "SubscriptionPlan" + TypeTLSFingerprintProfile = "TLSFingerprintProfile" + TypeUsageCleanupTask = "UsageCleanupTask" + TypeUsageLog = "UsageLog" + TypeUser = "User" + TypeUserAllowedGroup = "UserAllowedGroup" + TypeUserAttributeDefinition = "UserAttributeDefinition" + TypeUserAttributeValue = "UserAttributeValue" + TypeUserSubscription = "UserSubscription" ) // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. @@ -6887,49 +6895,45 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } -// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. -type ErrorPassthroughRuleMutation struct { +// AuthIdentityMutation represents an operation that mutates the AuthIdentity nodes in the graph. +type AuthIdentityMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - name *string - enabled *bool - priority *int - addpriority *int - error_codes *[]int - appenderror_codes []int - keywords *[]string - appendkeywords []string - match_mode *string - platforms *[]string - appendplatforms []string - passthrough_code *bool - response_code *int - addresponse_code *int - passthrough_body *bool - custom_message *string - skip_monitoring *bool - description *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*ErrorPassthroughRule, error) - predicates []predicate.ErrorPassthroughRule + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + provider_subject *string + verified_at *time.Time + issuer *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + user *int64 + cleareduser bool + channels map[int64]struct{} + removedchannels map[int64]struct{} + clearedchannels bool + adoption_decisions map[int64]struct{} + removedadoption_decisions map[int64]struct{} + clearedadoption_decisions bool + done bool + oldValue func(context.Context) (*AuthIdentity, error) + predicates []predicate.AuthIdentity } -var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) +var _ ent.Mutation = (*AuthIdentityMutation)(nil) -// errorpassthroughruleOption allows management of the mutation configuration using functional options. -type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) +// authidentityOption allows management of the mutation configuration using functional options. +type authidentityOption func(*AuthIdentityMutation) -// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. -func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { - m := &ErrorPassthroughRuleMutation{ +// newAuthIdentityMutation creates new mutation for the AuthIdentity entity. +func newAuthIdentityMutation(c config, op Op, opts ...authidentityOption) *AuthIdentityMutation { + m := &AuthIdentityMutation{ config: c, op: op, - typ: TypeErrorPassthroughRule, + typ: TypeAuthIdentity, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -6938,20 +6942,20 @@ func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughru return m } -// withErrorPassthroughRuleID sets the ID field of the mutation. -func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { +// withAuthIdentityID sets the ID field of the mutation. +func withAuthIdentityID(id int64) authidentityOption { + return func(m *AuthIdentityMutation) { var ( err error once sync.Once - value *ErrorPassthroughRule + value *AuthIdentity ) - m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + m.oldValue = func(ctx context.Context) (*AuthIdentity, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + value, err = m.Client().AuthIdentity.Get(ctx, id) } }) return value, err @@ -6960,10 +6964,10 @@ func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { } } -// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. -func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { - return func(m *ErrorPassthroughRuleMutation) { - m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { +// withAuthIdentity sets the old AuthIdentity of the mutation. +func withAuthIdentity(node *AuthIdentity) authidentityOption { + return func(m *AuthIdentityMutation) { + m.oldValue = func(context.Context) (*AuthIdentity, error) { return node, nil } m.id = &node.ID @@ -6972,7 +6976,7 @@ func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOp // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m ErrorPassthroughRuleMutation) Client() *Client { +func (m AuthIdentityMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -6980,7 +6984,7 @@ func (m ErrorPassthroughRuleMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { +func (m AuthIdentityMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -6991,7 +6995,7 @@ func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { +func (m *AuthIdentityMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -7002,7 +7006,7 @@ func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *AuthIdentityMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -7011,19 +7015,19 @@ func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + return m.Client().AuthIdentity.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. -func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -7031,10 +7035,10 @@ func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -7049,17 +7053,17 @@ func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { +func (m *AuthIdentityMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { +func (m *AuthIdentityMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *AuthIdentityMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -7067,10 +7071,10 @@ func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *AuthIdentityMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -7085,942 +7089,1510 @@ func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { +func (m *AuthIdentityMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetName sets the "name" field. -func (m *ErrorPassthroughRuleMutation) SetName(s string) { - m.name = &s +// SetUserID sets the "user_id" field. +func (m *AuthIdentityMutation) SetUserID(i int64) { + m.user = &i } -// Name returns the value of the "name" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { - v := m.name +// UserID returns the value of the "user_id" field in the mutation. +func (m *AuthIdentityMutation) UserID() (r int64, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldUserID returns the old "user_id" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldUserID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.Name, nil + return oldValue.UserID, nil } -// ResetName resets all changes to the "name" field. -func (m *ErrorPassthroughRuleMutation) ResetName() { - m.name = nil +// ResetUserID resets all changes to the "user_id" field. +func (m *AuthIdentityMutation) ResetUserID() { + m.user = nil } -// SetEnabled sets the "enabled" field. -func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { - m.enabled = &b +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityMutation) SetProviderType(s string) { + m.provider_type = &s } -// Enabled returns the value of the "enabled" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityMutation) ProviderType() (r string, exists bool) { + v := m.provider_type if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderType returns the old "provider_type" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *AuthIdentityMutation) OldProviderType(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldProviderType requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) } - return oldValue.Enabled, nil + return oldValue.ProviderType, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *ErrorPassthroughRuleMutation) ResetEnabled() { - m.enabled = nil +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityMutation) ResetProviderType() { + m.provider_type = nil } -// SetPriority sets the "priority" field. -func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { - m.priority = &i - m.addpriority = nil +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityMutation) SetProviderKey(s string) { + m.provider_key = &s } -// Priority returns the value of the "priority" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { - v := m.priority +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key if v == nil { return } return *v, true } -// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { +func (m *AuthIdentityMutation) OldProviderKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPriority is only allowed on UpdateOne operations") + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPriority requires an ID field in the mutation") + return v, errors.New("OldProviderKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPriority: %w", err) + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) } - return oldValue.Priority, nil + return oldValue.ProviderKey, nil } -// AddPriority adds i to the "priority" field. -func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { - if m.addpriority != nil { - *m.addpriority += i - } else { - m.addpriority = &i - } +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityMutation) ResetProviderKey() { + m.provider_key = nil } -// AddedPriority returns the value that was added to the "priority" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { - v := m.addpriority +// SetProviderSubject sets the "provider_subject" field. +func (m *AuthIdentityMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *AuthIdentityMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject if v == nil { return } return *v, true } -// ResetPriority resets all changes to the "priority" field. -func (m *ErrorPassthroughRuleMutation) ResetPriority() { - m.priority = nil - m.addpriority = nil +// OldProviderSubject returns the old "provider_subject" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil } -// SetErrorCodes sets the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { - m.error_codes = &i - m.appenderror_codes = nil +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *AuthIdentityMutation) ResetProviderSubject() { + m.provider_subject = nil } -// ErrorCodes returns the value of the "error_codes" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { - v := m.error_codes +// SetVerifiedAt sets the "verified_at" field. +func (m *AuthIdentityMutation) SetVerifiedAt(t time.Time) { + m.verified_at = &t +} + +// VerifiedAt returns the value of the "verified_at" field in the mutation. +func (m *AuthIdentityMutation) VerifiedAt() (r time.Time, exists bool) { + v := m.verified_at if v == nil { return } return *v, true } -// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldVerifiedAt returns the old "verified_at" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { +func (m *AuthIdentityMutation) OldVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + return v, errors.New("OldVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorCodes requires an ID field in the mutation") + return v, errors.New("OldVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) - } - return oldValue.ErrorCodes, nil -} - -// AppendErrorCodes adds i to the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { - m.appenderror_codes = append(m.appenderror_codes, i...) -} - -// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { - if len(m.appenderror_codes) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldVerifiedAt: %w", err) } - return m.appenderror_codes, true + return oldValue.VerifiedAt, nil } -// ClearErrorCodes clears the value of the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} +// ClearVerifiedAt clears the value of the "verified_at" field. +func (m *AuthIdentityMutation) ClearVerifiedAt() { + m.verified_at = nil + m.clearedFields[authidentity.FieldVerifiedAt] = struct{}{} } -// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] +// VerifiedAtCleared returns if the "verified_at" field was cleared in this mutation. +func (m *AuthIdentityMutation) VerifiedAtCleared() bool { + _, ok := m.clearedFields[authidentity.FieldVerifiedAt] return ok } -// ResetErrorCodes resets all changes to the "error_codes" field. -func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { - m.error_codes = nil - m.appenderror_codes = nil - delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +// ResetVerifiedAt resets all changes to the "verified_at" field. +func (m *AuthIdentityMutation) ResetVerifiedAt() { + m.verified_at = nil + delete(m.clearedFields, authidentity.FieldVerifiedAt) } -// SetKeywords sets the "keywords" field. -func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { - m.keywords = &s - m.appendkeywords = nil +// SetIssuer sets the "issuer" field. +func (m *AuthIdentityMutation) SetIssuer(s string) { + m.issuer = &s } -// Keywords returns the value of the "keywords" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { - v := m.keywords +// Issuer returns the value of the "issuer" field in the mutation. +func (m *AuthIdentityMutation) Issuer() (r string, exists bool) { + v := m.issuer if v == nil { return } return *v, true } -// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldIssuer returns the old "issuer" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { +func (m *AuthIdentityMutation) OldIssuer(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + return v, errors.New("OldIssuer is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldKeywords requires an ID field in the mutation") + return v, errors.New("OldIssuer requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldKeywords: %w", err) - } - return oldValue.Keywords, nil -} - -// AppendKeywords adds s to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { - m.appendkeywords = append(m.appendkeywords, s...) -} - -// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { - if len(m.appendkeywords) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldIssuer: %w", err) } - return m.appendkeywords, true + return oldValue.Issuer, nil } -// ClearKeywords clears the value of the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ClearKeywords() { - m.keywords = nil - m.appendkeywords = nil - m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +// ClearIssuer clears the value of the "issuer" field. +func (m *AuthIdentityMutation) ClearIssuer() { + m.issuer = nil + m.clearedFields[authidentity.FieldIssuer] = struct{}{} } -// KeywordsCleared returns if the "keywords" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] +// IssuerCleared returns if the "issuer" field was cleared in this mutation. +func (m *AuthIdentityMutation) IssuerCleared() bool { + _, ok := m.clearedFields[authidentity.FieldIssuer] return ok } -// ResetKeywords resets all changes to the "keywords" field. -func (m *ErrorPassthroughRuleMutation) ResetKeywords() { - m.keywords = nil - m.appendkeywords = nil - delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +// ResetIssuer resets all changes to the "issuer" field. +func (m *AuthIdentityMutation) ResetIssuer() { + m.issuer = nil + delete(m.clearedFields, authidentity.FieldIssuer) } -// SetMatchMode sets the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { - m.match_mode = &s +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value } -// MatchMode returns the value of the "match_mode" field in the mutation. -func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { - v := m.match_mode +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata if v == nil { return } return *v, true } -// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// OldMetadata returns the old "metadata" field's value of the AuthIdentity entity. +// If the AuthIdentity object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { +func (m *AuthIdentityMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMatchMode requires an ID field in the mutation") + return v, errors.New("OldMetadata requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) } - return oldValue.MatchMode, nil + return oldValue.Metadata, nil } -// ResetMatchMode resets all changes to the "match_mode" field. -func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { - m.match_mode = nil +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityMutation) ResetMetadata() { + m.metadata = nil } -// SetPlatforms sets the "platforms" field. -func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { - m.platforms = &s - m.appendplatforms = nil +// ClearUser clears the "user" edge to the User entity. +func (m *AuthIdentityMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[authidentity.FieldUserID] = struct{}{} } -// Platforms returns the value of the "platforms" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { - v := m.platforms - if v == nil { - return - } - return *v, true +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *AuthIdentityMutation) UserCleared() bool { + return m.cleareduser } -// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatforms requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) } - return oldValue.Platforms, nil + return } -// AppendPlatforms adds s to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { - m.appendplatforms = append(m.appendplatforms, s...) +// ResetUser resets all changes to the "user" edge. +func (m *AuthIdentityMutation) ResetUser() { + m.user = nil + m.cleareduser = false } -// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { - if len(m.appendplatforms) == 0 { - return nil, false +// AddChannelIDs adds the "channels" edge to the AuthIdentityChannel entity by ids. +func (m *AuthIdentityMutation) AddChannelIDs(ids ...int64) { + if m.channels == nil { + m.channels = make(map[int64]struct{}) + } + for i := range ids { + m.channels[ids[i]] = struct{}{} } - return m.appendplatforms, true } -// ClearPlatforms clears the value of the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { - m.platforms = nil - m.appendplatforms = nil - m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +// ClearChannels clears the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) ClearChannels() { + m.clearedchannels = true } -// PlatformsCleared returns if the "platforms" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] - return ok +// ChannelsCleared reports if the "channels" edge to the AuthIdentityChannel entity was cleared. +func (m *AuthIdentityMutation) ChannelsCleared() bool { + return m.clearedchannels } -// ResetPlatforms resets all changes to the "platforms" field. -func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { - m.platforms = nil - m.appendplatforms = nil - delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +// RemoveChannelIDs removes the "channels" edge to the AuthIdentityChannel entity by IDs. +func (m *AuthIdentityMutation) RemoveChannelIDs(ids ...int64) { + if m.removedchannels == nil { + m.removedchannels = make(map[int64]struct{}) + } + for i := range ids { + delete(m.channels, ids[i]) + m.removedchannels[ids[i]] = struct{}{} + } } -// SetPassthroughCode sets the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { - m.passthrough_code = &b +// RemovedChannels returns the removed IDs of the "channels" edge to the AuthIdentityChannel entity. +func (m *AuthIdentityMutation) RemovedChannelsIDs() (ids []int64) { + for id := range m.removedchannels { + ids = append(ids, id) + } + return } -// PassthroughCode returns the value of the "passthrough_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { - v := m.passthrough_code - if v == nil { - return +// ChannelsIDs returns the "channels" edge IDs in the mutation. +func (m *AuthIdentityMutation) ChannelsIDs() (ids []int64) { + for id := range m.channels { + ids = append(ids, id) } - return *v, true + return } -// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughCode requires an ID field in the mutation") +// ResetChannels resets all changes to the "channels" edge. +func (m *AuthIdentityMutation) ResetChannels() { + m.channels = nil + m.clearedchannels = false + m.removedchannels = nil +} + +// AddAdoptionDecisionIDs adds the "adoption_decisions" edge to the IdentityAdoptionDecision entity by ids. +func (m *AuthIdentityMutation) AddAdoptionDecisionIDs(ids ...int64) { + if m.adoption_decisions == nil { + m.adoption_decisions = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + for i := range ids { + m.adoption_decisions[ids[i]] = struct{}{} } - return oldValue.PassthroughCode, nil } -// ResetPassthroughCode resets all changes to the "passthrough_code" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { - m.passthrough_code = nil +// ClearAdoptionDecisions clears the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) ClearAdoptionDecisions() { + m.clearedadoption_decisions = true } -// SetResponseCode sets the "response_code" field. -func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { - m.response_code = &i - m.addresponse_code = nil +// AdoptionDecisionsCleared reports if the "adoption_decisions" edge to the IdentityAdoptionDecision entity was cleared. +func (m *AuthIdentityMutation) AdoptionDecisionsCleared() bool { + return m.clearedadoption_decisions } -// ResponseCode returns the value of the "response_code" field in the mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { - v := m.response_code - if v == nil { - return +// RemoveAdoptionDecisionIDs removes the "adoption_decisions" edge to the IdentityAdoptionDecision entity by IDs. +func (m *AuthIdentityMutation) RemoveAdoptionDecisionIDs(ids ...int64) { + if m.removedadoption_decisions == nil { + m.removedadoption_decisions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.adoption_decisions, ids[i]) + m.removedadoption_decisions[ids[i]] = struct{}{} } - return *v, true } -// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseCode requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) +// RemovedAdoptionDecisions returns the removed IDs of the "adoption_decisions" edge to the IdentityAdoptionDecision entity. +func (m *AuthIdentityMutation) RemovedAdoptionDecisionsIDs() (ids []int64) { + for id := range m.removedadoption_decisions { + ids = append(ids, id) } - return oldValue.ResponseCode, nil + return } -// AddResponseCode adds i to the "response_code" field. -func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { - if m.addresponse_code != nil { - *m.addresponse_code += i - } else { - m.addresponse_code = &i +// AdoptionDecisionsIDs returns the "adoption_decisions" edge IDs in the mutation. +func (m *AuthIdentityMutation) AdoptionDecisionsIDs() (ids []int64) { + for id := range m.adoption_decisions { + ids = append(ids, id) } + return } -// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { - v := m.addresponse_code - if v == nil { - return - } - return *v, true +// ResetAdoptionDecisions resets all changes to the "adoption_decisions" edge. +func (m *AuthIdentityMutation) ResetAdoptionDecisions() { + m.adoption_decisions = nil + m.clearedadoption_decisions = false + m.removedadoption_decisions = nil } -// ClearResponseCode clears the value of the "response_code" field. -func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { - m.response_code = nil - m.addresponse_code = nil - m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +// Where appends a list predicates to the AuthIdentityMutation builder. +func (m *AuthIdentityMutation) Where(ps ...predicate.AuthIdentity) { + m.predicates = append(m.predicates, ps...) } -// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] - return ok +// WhereP appends storage-level predicates to the AuthIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// ResetResponseCode resets all changes to the "response_code" field. -func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { - m.response_code = nil - m.addresponse_code = nil - delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +// Op returns the operation name. +func (m *AuthIdentityMutation) Op() Op { + return m.op } -// SetPassthroughBody sets the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { - m.passthrough_body = &b +// SetOp allows setting the mutation operation. +func (m *AuthIdentityMutation) SetOp(op Op) { + m.op = op } -// PassthroughBody returns the value of the "passthrough_body" field in the mutation. -func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { - v := m.passthrough_body - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (AuthIdentity). +func (m *AuthIdentityMutation) Type() string { + return m.typ } -// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentity.FieldCreatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPassthroughBody requires an ID field in the mutation") + if m.updated_at != nil { + fields = append(fields, authidentity.FieldUpdatedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) + if m.user != nil { + fields = append(fields, authidentity.FieldUserID) } - return oldValue.PassthroughBody, nil -} - -// ResetPassthroughBody resets all changes to the "passthrough_body" field. -func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { - m.passthrough_body = nil -} - -// SetCustomMessage sets the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { - m.custom_message = &s -} - -// CustomMessage returns the value of the "custom_message" field in the mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { - v := m.custom_message - if v == nil { - return + if m.provider_type != nil { + fields = append(fields, authidentity.FieldProviderType) } - return *v, true -} - -// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") + if m.provider_key != nil { + fields = append(fields, authidentity.FieldProviderKey) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCustomMessage requires an ID field in the mutation") + if m.provider_subject != nil { + fields = append(fields, authidentity.FieldProviderSubject) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) + if m.verified_at != nil { + fields = append(fields, authidentity.FieldVerifiedAt) } - return oldValue.CustomMessage, nil -} - -// ClearCustomMessage clears the value of the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { - m.custom_message = nil - m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} + if m.issuer != nil { + fields = append(fields, authidentity.FieldIssuer) + } + if m.metadata != nil { + fields = append(fields, authidentity.FieldMetadata) + } + return fields } -// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] - return ok +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentity.FieldCreatedAt: + return m.CreatedAt() + case authidentity.FieldUpdatedAt: + return m.UpdatedAt() + case authidentity.FieldUserID: + return m.UserID() + case authidentity.FieldProviderType: + return m.ProviderType() + case authidentity.FieldProviderKey: + return m.ProviderKey() + case authidentity.FieldProviderSubject: + return m.ProviderSubject() + case authidentity.FieldVerifiedAt: + return m.VerifiedAt() + case authidentity.FieldIssuer: + return m.Issuer() + case authidentity.FieldMetadata: + return m.Metadata() + } + return nil, false } -// ResetCustomMessage resets all changes to the "custom_message" field. -func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { - m.custom_message = nil - delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentity.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentity.FieldUserID: + return m.OldUserID(ctx) + case authidentity.FieldProviderType: + return m.OldProviderType(ctx) + case authidentity.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentity.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case authidentity.FieldVerifiedAt: + return m.OldVerifiedAt(ctx) + case authidentity.FieldIssuer: + return m.OldIssuer(ctx) + case authidentity.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentity field %s", name) } -// SetSkipMonitoring sets the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { - m.skip_monitoring = &b +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authidentity.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case authidentity.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case authidentity.FieldProviderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderType(v) + return nil + case authidentity.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil + case authidentity.FieldProviderSubject: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderSubject(v) + return nil + case authidentity.FieldVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetVerifiedAt(v) + return nil + case authidentity.FieldIssuer: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIssuer(v) + return nil + case authidentity.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown AuthIdentity field %s", name) } -// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. -func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { - v := m.skip_monitoring - if v == nil { - return +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthIdentityMutation) AddedFields() []string { + var fields []string + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthIdentityMutation) AddedField(name string) (ent.Value, bool) { + switch name { } - return *v, true + return nil, false } -// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityMutation) AddField(name string, value ent.Value) error { + switch name { } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") + return fmt.Errorf("unknown AuthIdentity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthIdentityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(authidentity.FieldVerifiedAt) { + fields = append(fields, authidentity.FieldVerifiedAt) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) + if m.FieldCleared(authidentity.FieldIssuer) { + fields = append(fields, authidentity.FieldIssuer) } - return oldValue.SkipMonitoring, nil + return fields } -// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. -func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { - m.skip_monitoring = nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthIdentityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// SetDescription sets the "description" field. -func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { - m.description = &s +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ClearField(name string) error { + switch name { + case authidentity.FieldVerifiedAt: + m.ClearVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ClearIssuer() + return nil + } + return fmt.Errorf("unknown AuthIdentity nullable field %s", name) } -// Description returns the value of the "description" field in the mutation. -func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { - v := m.description - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthIdentityMutation) ResetField(name string) error { + switch name { + case authidentity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authidentity.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case authidentity.FieldUserID: + m.ResetUserID() + return nil + case authidentity.FieldProviderType: + m.ResetProviderType() + return nil + case authidentity.FieldProviderKey: + m.ResetProviderKey() + return nil + case authidentity.FieldProviderSubject: + m.ResetProviderSubject() + return nil + case authidentity.FieldVerifiedAt: + m.ResetVerifiedAt() + return nil + case authidentity.FieldIssuer: + m.ResetIssuer() + return nil + case authidentity.FieldMetadata: + m.ResetMetadata() + return nil } - return *v, true + return fmt.Errorf("unknown AuthIdentity field %s", name) } -// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. -// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.user != nil { + edges = append(edges, authidentity.EdgeUser) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + if m.channels != nil { + edges = append(edges, authidentity.EdgeChannels) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + if m.adoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) } - return oldValue.Description, nil + return edges } -// ClearDescription clears the value of the "description" field. -func (m *ErrorPassthroughRuleMutation) ClearDescription() { - m.description = nil - m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthIdentityMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} + } + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.channels)) + for id := range m.channels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.adoption_decisions)) + for id := range m.adoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil } -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] - return ok +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedchannels != nil { + edges = append(edges, authidentity.EdgeChannels) + } + if m.removedadoption_decisions != nil { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges } -// ResetDescription resets all changes to the "description" field. -func (m *ErrorPassthroughRuleMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, errorpassthroughrule.FieldDescription) +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthIdentityMutation) RemovedIDs(name string) []ent.Value { + switch name { + case authidentity.EdgeChannels: + ids := make([]ent.Value, 0, len(m.removedchannels)) + for id := range m.removedchannels { + ids = append(ids, id) + } + return ids + case authidentity.EdgeAdoptionDecisions: + ids := make([]ent.Value, 0, len(m.removedadoption_decisions)) + for id := range m.removedadoption_decisions { + ids = append(ids, id) + } + return ids + } + return nil } -// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. -func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { - m.predicates = append(m.predicates, ps...) +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.cleareduser { + edges = append(edges, authidentity.EdgeUser) + } + if m.clearedchannels { + edges = append(edges, authidentity.EdgeChannels) + } + if m.clearedadoption_decisions { + edges = append(edges, authidentity.EdgeAdoptionDecisions) + } + return edges } -// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.ErrorPassthroughRule, len(ps)) - for i := range ps { - p[i] = ps[i] +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthIdentityMutation) EdgeCleared(name string) bool { + switch name { + case authidentity.EdgeUser: + return m.cleareduser + case authidentity.EdgeChannels: + return m.clearedchannels + case authidentity.EdgeAdoptionDecisions: + return m.clearedadoption_decisions } - m.Where(p...) + return false } -// Op returns the operation name. -func (m *ErrorPassthroughRuleMutation) Op() Op { - return m.op +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthIdentityMutation) ClearEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ClearUser() + return nil + } + return fmt.Errorf("unknown AuthIdentity unique edge %s", name) } -// SetOp allows setting the mutation operation. -func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { - m.op = op +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthIdentityMutation) ResetEdge(name string) error { + switch name { + case authidentity.EdgeUser: + m.ResetUser() + return nil + case authidentity.EdgeChannels: + m.ResetChannels() + return nil + case authidentity.EdgeAdoptionDecisions: + m.ResetAdoptionDecisions() + return nil + } + return fmt.Errorf("unknown AuthIdentity edge %s", name) } -// Type returns the node type of this mutation (ErrorPassthroughRule). -func (m *ErrorPassthroughRuleMutation) Type() string { - return m.typ +// AuthIdentityChannelMutation represents an operation that mutates the AuthIdentityChannel nodes in the graph. +type AuthIdentityChannelMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + provider_type *string + provider_key *string + channel *string + channel_app_id *string + channel_subject *string + metadata *map[string]interface{} + clearedFields map[string]struct{} + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*AuthIdentityChannel, error) + predicates []predicate.AuthIdentityChannel } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *ErrorPassthroughRuleMutation) Fields() []string { - fields := make([]string, 0, 15) - if m.created_at != nil { - fields = append(fields, errorpassthroughrule.FieldCreatedAt) +var _ ent.Mutation = (*AuthIdentityChannelMutation)(nil) + +// authidentitychannelOption allows management of the mutation configuration using functional options. +type authidentitychannelOption func(*AuthIdentityChannelMutation) + +// newAuthIdentityChannelMutation creates new mutation for the AuthIdentityChannel entity. +func newAuthIdentityChannelMutation(c config, op Op, opts ...authidentitychannelOption) *AuthIdentityChannelMutation { + m := &AuthIdentityChannelMutation{ + config: c, + op: op, + typ: TypeAuthIdentityChannel, + clearedFields: make(map[string]struct{}), } - if m.updated_at != nil { - fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + for _, opt := range opts { + opt(m) } - if m.name != nil { - fields = append(fields, errorpassthroughrule.FieldName) + return m +} + +// withAuthIdentityChannelID sets the ID field of the mutation. +func withAuthIdentityChannelID(id int64) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + var ( + err error + once sync.Once + value *AuthIdentityChannel + ) + m.oldValue = func(ctx context.Context) (*AuthIdentityChannel, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthIdentityChannel.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - if m.enabled != nil { - fields = append(fields, errorpassthroughrule.FieldEnabled) +} + +// withAuthIdentityChannel sets the old AuthIdentityChannel of the mutation. +func withAuthIdentityChannel(node *AuthIdentityChannel) authidentitychannelOption { + return func(m *AuthIdentityChannelMutation) { + m.oldValue = func(context.Context) (*AuthIdentityChannel, error) { + return node, nil + } + m.id = &node.ID } - if m.priority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthIdentityChannelMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthIdentityChannelMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") } - if m.error_codes != nil { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthIdentityChannelMutation) ID() (id int64, exists bool) { + if m.id == nil { + return } - if m.keywords != nil { - fields = append(fields, errorpassthroughrule.FieldKeywords) + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthIdentityChannelMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthIdentityChannel.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } - if m.match_mode != nil { - fields = append(fields, errorpassthroughrule.FieldMatchMode) +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthIdentityChannelMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthIdentityChannelMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return } - if m.platforms != nil { - fields = append(fields, errorpassthroughrule.FieldPlatforms) + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } - if m.passthrough_code != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } - if m.response_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - if m.passthrough_body != nil { - fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthIdentityChannelMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *AuthIdentityChannelMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *AuthIdentityChannelMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return } - if m.custom_message != nil { - fields = append(fields, errorpassthroughrule.FieldCustomMessage) + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } - if m.skip_monitoring != nil { - fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } - if m.description != nil { - fields = append(fields, errorpassthroughrule.FieldDescription) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return fields + return oldValue.UpdatedAt, nil } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { - switch name { - case errorpassthroughrule.FieldCreatedAt: - return m.CreatedAt() - case errorpassthroughrule.FieldUpdatedAt: - return m.UpdatedAt() - case errorpassthroughrule.FieldName: - return m.Name() - case errorpassthroughrule.FieldEnabled: - return m.Enabled() - case errorpassthroughrule.FieldPriority: - return m.Priority() - case errorpassthroughrule.FieldErrorCodes: - return m.ErrorCodes() - case errorpassthroughrule.FieldKeywords: - return m.Keywords() - case errorpassthroughrule.FieldMatchMode: - return m.MatchMode() - case errorpassthroughrule.FieldPlatforms: - return m.Platforms() - case errorpassthroughrule.FieldPassthroughCode: - return m.PassthroughCode() - case errorpassthroughrule.FieldResponseCode: - return m.ResponseCode() - case errorpassthroughrule.FieldPassthroughBody: - return m.PassthroughBody() - case errorpassthroughrule.FieldCustomMessage: - return m.CustomMessage() - case errorpassthroughrule.FieldSkipMonitoring: - return m.SkipMonitoring() - case errorpassthroughrule.FieldDescription: - return m.Description() +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *AuthIdentityChannelMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetIdentityID sets the "identity_id" field. +func (m *AuthIdentityChannelMutation) SetIdentityID(i int64) { + m.identity = &i +} + +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *AuthIdentityChannelMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return } - return nil, false + return *v, true } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case errorpassthroughrule.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case errorpassthroughrule.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case errorpassthroughrule.FieldName: - return m.OldName(ctx) - case errorpassthroughrule.FieldEnabled: - return m.OldEnabled(ctx) - case errorpassthroughrule.FieldPriority: - return m.OldPriority(ctx) - case errorpassthroughrule.FieldErrorCodes: - return m.OldErrorCodes(ctx) - case errorpassthroughrule.FieldKeywords: - return m.OldKeywords(ctx) - case errorpassthroughrule.FieldMatchMode: - return m.OldMatchMode(ctx) - case errorpassthroughrule.FieldPlatforms: - return m.OldPlatforms(ctx) - case errorpassthroughrule.FieldPassthroughCode: - return m.OldPassthroughCode(ctx) - case errorpassthroughrule.FieldResponseCode: - return m.OldResponseCode(ctx) - case errorpassthroughrule.FieldPassthroughBody: - return m.OldPassthroughBody(ctx) - case errorpassthroughrule.FieldCustomMessage: - return m.OldCustomMessage(ctx) - case errorpassthroughrule.FieldSkipMonitoring: - return m.OldSkipMonitoring(ctx) - case errorpassthroughrule.FieldDescription: - return m.OldDescription(ctx) +// OldIdentityID returns the old "identity_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldIdentityID(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") } - return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdentityID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) + } + return oldValue.IdentityID, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { - switch name { - case errorpassthroughrule.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil - case errorpassthroughrule.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdatedAt(v) - return nil - case errorpassthroughrule.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case errorpassthroughrule.FieldEnabled: - v, ok := value.(bool) +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *AuthIdentityChannelMutation) ResetIdentityID() { + m.identity = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *AuthIdentityChannelMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *AuthIdentityChannelMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *AuthIdentityChannelMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *AuthIdentityChannelMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *AuthIdentityChannelMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetChannel sets the "channel" field. +func (m *AuthIdentityChannelMutation) SetChannel(s string) { + m.channel = &s +} + +// Channel returns the value of the "channel" field in the mutation. +func (m *AuthIdentityChannelMutation) Channel() (r string, exists bool) { + v := m.channel + if v == nil { + return + } + return *v, true +} + +// OldChannel returns the old "channel" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannel(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannel: %w", err) + } + return oldValue.Channel, nil +} + +// ResetChannel resets all changes to the "channel" field. +func (m *AuthIdentityChannelMutation) ResetChannel() { + m.channel = nil +} + +// SetChannelAppID sets the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) SetChannelAppID(s string) { + m.channel_app_id = &s +} + +// ChannelAppID returns the value of the "channel_app_id" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelAppID() (r string, exists bool) { + v := m.channel_app_id + if v == nil { + return + } + return *v, true +} + +// OldChannelAppID returns the old "channel_app_id" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelAppID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelAppID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelAppID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelAppID: %w", err) + } + return oldValue.ChannelAppID, nil +} + +// ResetChannelAppID resets all changes to the "channel_app_id" field. +func (m *AuthIdentityChannelMutation) ResetChannelAppID() { + m.channel_app_id = nil +} + +// SetChannelSubject sets the "channel_subject" field. +func (m *AuthIdentityChannelMutation) SetChannelSubject(s string) { + m.channel_subject = &s +} + +// ChannelSubject returns the value of the "channel_subject" field in the mutation. +func (m *AuthIdentityChannelMutation) ChannelSubject() (r string, exists bool) { + v := m.channel_subject + if v == nil { + return + } + return *v, true +} + +// OldChannelSubject returns the old "channel_subject" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldChannelSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldChannelSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldChannelSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldChannelSubject: %w", err) + } + return oldValue.ChannelSubject, nil +} + +// ResetChannelSubject resets all changes to the "channel_subject" field. +func (m *AuthIdentityChannelMutation) ResetChannelSubject() { + m.channel_subject = nil +} + +// SetMetadata sets the "metadata" field. +func (m *AuthIdentityChannelMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} + +// Metadata returns the value of the "metadata" field in the mutation. +func (m *AuthIdentityChannelMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} + +// OldMetadata returns the old "metadata" field's value of the AuthIdentityChannel entity. +// If the AuthIdentityChannel object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthIdentityChannelMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil +} + +// ResetMetadata resets all changes to the "metadata" field. +func (m *AuthIdentityChannelMutation) ResetMetadata() { + m.metadata = nil +} + +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *AuthIdentityChannelMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[authidentitychannel.FieldIdentityID] = struct{}{} +} + +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *AuthIdentityChannelMutation) IdentityCleared() bool { + return m.clearedidentity +} + +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *AuthIdentityChannelMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetIdentity resets all changes to the "identity" edge. +func (m *AuthIdentityChannelMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false +} + +// Where appends a list predicates to the AuthIdentityChannelMutation builder. +func (m *AuthIdentityChannelMutation) Where(ps ...predicate.AuthIdentityChannel) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthIdentityChannelMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthIdentityChannelMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthIdentityChannel, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthIdentityChannelMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthIdentityChannelMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthIdentityChannel). +func (m *AuthIdentityChannelMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthIdentityChannelMutation) Fields() []string { + fields := make([]string, 0, 9) + if m.created_at != nil { + fields = append(fields, authidentitychannel.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, authidentitychannel.FieldUpdatedAt) + } + if m.identity != nil { + fields = append(fields, authidentitychannel.FieldIdentityID) + } + if m.provider_type != nil { + fields = append(fields, authidentitychannel.FieldProviderType) + } + if m.provider_key != nil { + fields = append(fields, authidentitychannel.FieldProviderKey) + } + if m.channel != nil { + fields = append(fields, authidentitychannel.FieldChannel) + } + if m.channel_app_id != nil { + fields = append(fields, authidentitychannel.FieldChannelAppID) + } + if m.channel_subject != nil { + fields = append(fields, authidentitychannel.FieldChannelSubject) + } + if m.metadata != nil { + fields = append(fields, authidentitychannel.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthIdentityChannelMutation) Field(name string) (ent.Value, bool) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.CreatedAt() + case authidentitychannel.FieldUpdatedAt: + return m.UpdatedAt() + case authidentitychannel.FieldIdentityID: + return m.IdentityID() + case authidentitychannel.FieldProviderType: + return m.ProviderType() + case authidentitychannel.FieldProviderKey: + return m.ProviderKey() + case authidentitychannel.FieldChannel: + return m.Channel() + case authidentitychannel.FieldChannelAppID: + return m.ChannelAppID() + case authidentitychannel.FieldChannelSubject: + return m.ChannelSubject() + case authidentitychannel.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthIdentityChannelMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authidentitychannel.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authidentitychannel.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case authidentitychannel.FieldIdentityID: + return m.OldIdentityID(ctx) + case authidentitychannel.FieldProviderType: + return m.OldProviderType(ctx) + case authidentitychannel.FieldProviderKey: + return m.OldProviderKey(ctx) + case authidentitychannel.FieldChannel: + return m.OldChannel(ctx) + case authidentitychannel.FieldChannelAppID: + return m.OldChannelAppID(ctx) + case authidentitychannel.FieldChannelSubject: + return m.OldChannelSubject(ctx) + case authidentitychannel.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown AuthIdentityChannel field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthIdentityChannelMutation) SetField(name string, value ent.Value) error { + switch name { + case authidentitychannel.FieldCreatedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetEnabled(v) + m.SetCreatedAt(v) return nil - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) + case authidentitychannel.FieldUpdatedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPriority(v) + m.SetUpdatedAt(v) return nil - case errorpassthroughrule.FieldErrorCodes: - v, ok := value.([]int) + case authidentitychannel.FieldIdentityID: + v, ok := value.(int64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetErrorCodes(v) + m.SetIdentityID(v) return nil - case errorpassthroughrule.FieldKeywords: - v, ok := value.([]string) + case authidentitychannel.FieldProviderType: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetKeywords(v) + m.SetProviderType(v) return nil - case errorpassthroughrule.FieldMatchMode: + case authidentitychannel.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetMatchMode(v) + m.SetProviderKey(v) return nil - case errorpassthroughrule.FieldPlatforms: - v, ok := value.([]string) + case authidentitychannel.FieldChannel: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPlatforms(v) + m.SetChannel(v) return nil - case errorpassthroughrule.FieldPassthroughCode: - v, ok := value.(bool) + case authidentitychannel.FieldChannelAppID: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughCode(v) + m.SetChannelAppID(v) return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) + case authidentitychannel.FieldChannelSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseCode(v) + m.SetChannelSubject(v) return nil - case errorpassthroughrule.FieldPassthroughBody: - v, ok := value.(bool) + case authidentitychannel.FieldMetadata: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPassthroughBody(v) - return nil - case errorpassthroughrule.FieldCustomMessage: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCustomMessage(v) - return nil - case errorpassthroughrule.FieldSkipMonitoring: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSkipMonitoring(v) - return nil - case errorpassthroughrule.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) + m.SetMetadata(v) return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *ErrorPassthroughRuleMutation) AddedFields() []string { +func (m *AuthIdentityChannelMutation) AddedFields() []string { var fields []string - if m.addpriority != nil { - fields = append(fields, errorpassthroughrule.FieldPriority) - } - if m.addresponse_code != nil { - fields = append(fields, errorpassthroughrule.FieldResponseCode) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { +func (m *AuthIdentityChannelMutation) AddedField(name string) (ent.Value, bool) { switch name { - case errorpassthroughrule.FieldPriority: - return m.AddedPriority() - case errorpassthroughrule.FieldResponseCode: - return m.AddedResponseCode() } return nil, false } @@ -8028,290 +8600,205 @@ func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { +func (m *AuthIdentityChannelMutation) AddField(name string, value ent.Value) error { switch name { - case errorpassthroughrule.FieldPriority: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPriority(v) - return nil - case errorpassthroughrule.FieldResponseCode: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddResponseCode(v) - return nil } - return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { - fields = append(fields, errorpassthroughrule.FieldErrorCodes) - } - if m.FieldCleared(errorpassthroughrule.FieldKeywords) { - fields = append(fields, errorpassthroughrule.FieldKeywords) - } - if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { - fields = append(fields, errorpassthroughrule.FieldPlatforms) - } - if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { - fields = append(fields, errorpassthroughrule.FieldResponseCode) - } - if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { - fields = append(fields, errorpassthroughrule.FieldCustomMessage) - } - if m.FieldCleared(errorpassthroughrule.FieldDescription) { - fields = append(fields, errorpassthroughrule.FieldDescription) - } - return fields +func (m *AuthIdentityChannelMutation) ClearedFields() []string { + return nil } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { +func (m *AuthIdentityChannelMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { - switch name { - case errorpassthroughrule.FieldErrorCodes: - m.ClearErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ClearKeywords() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ClearPlatforms() - return nil - case errorpassthroughrule.FieldResponseCode: - m.ClearResponseCode() - return nil - case errorpassthroughrule.FieldCustomMessage: - m.ClearCustomMessage() - return nil - case errorpassthroughrule.FieldDescription: - m.ClearDescription() - return nil - } - return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +func (m *AuthIdentityChannelMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthIdentityChannel nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { +func (m *AuthIdentityChannelMutation) ResetField(name string) error { switch name { - case errorpassthroughrule.FieldCreatedAt: + case authidentitychannel.FieldCreatedAt: m.ResetCreatedAt() return nil - case errorpassthroughrule.FieldUpdatedAt: + case authidentitychannel.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case errorpassthroughrule.FieldName: - m.ResetName() - return nil - case errorpassthroughrule.FieldEnabled: - m.ResetEnabled() - return nil - case errorpassthroughrule.FieldPriority: - m.ResetPriority() - return nil - case errorpassthroughrule.FieldErrorCodes: - m.ResetErrorCodes() - return nil - case errorpassthroughrule.FieldKeywords: - m.ResetKeywords() - return nil - case errorpassthroughrule.FieldMatchMode: - m.ResetMatchMode() - return nil - case errorpassthroughrule.FieldPlatforms: - m.ResetPlatforms() + case authidentitychannel.FieldIdentityID: + m.ResetIdentityID() return nil - case errorpassthroughrule.FieldPassthroughCode: - m.ResetPassthroughCode() + case authidentitychannel.FieldProviderType: + m.ResetProviderType() return nil - case errorpassthroughrule.FieldResponseCode: - m.ResetResponseCode() + case authidentitychannel.FieldProviderKey: + m.ResetProviderKey() return nil - case errorpassthroughrule.FieldPassthroughBody: - m.ResetPassthroughBody() + case authidentitychannel.FieldChannel: + m.ResetChannel() return nil - case errorpassthroughrule.FieldCustomMessage: - m.ResetCustomMessage() + case authidentitychannel.FieldChannelAppID: + m.ResetChannelAppID() return nil - case errorpassthroughrule.FieldSkipMonitoring: - m.ResetSkipMonitoring() + case authidentitychannel.FieldChannelSubject: + m.ResetChannelSubject() return nil - case errorpassthroughrule.FieldDescription: - m.ResetDescription() + case authidentitychannel.FieldMetadata: + m.ResetMetadata() return nil } - return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) + return fmt.Errorf("unknown AuthIdentityChannel field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.identity != nil { + edges = append(edges, authidentitychannel.EdgeIdentity) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { +func (m *AuthIdentityChannelMutation) AddedIDs(name string) []ent.Value { + switch name { + case authidentitychannel.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { +func (m *AuthIdentityChannelMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *AuthIdentityChannelMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedidentity { + edges = append(edges, authidentitychannel.EdgeIdentity) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { +func (m *AuthIdentityChannelMutation) EdgeCleared(name string) bool { + switch name { + case authidentitychannel.EdgeIdentity: + return m.clearedidentity + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +func (m *AuthIdentityChannelMutation) ClearEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ClearIdentity() + return nil + } + return fmt.Errorf("unknown AuthIdentityChannel unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) -} - -// GroupMutation represents an operation that mutates the Group nodes in the graph. -type GroupMutation struct { - config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - fallback_group_id_on_invalid_request *int64 - addfallback_group_id_on_invalid_request *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - mcp_xml_inject *bool - supported_model_scopes *[]string - appendsupported_model_scopes []string - sort_order *int - addsort_order *int - allow_messages_dispatch *bool - require_oauth_only *bool - require_privacy_set *bool - default_mapped_model *string - messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group -} - -var _ ent.Mutation = (*GroupMutation)(nil) - -// groupOption allows management of the mutation configuration using functional options. -type groupOption func(*GroupMutation) - -// newGroupMutation creates new mutation for the Group entity. -func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { - m := &GroupMutation{ - config: c, - op: op, - typ: TypeGroup, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) +func (m *AuthIdentityChannelMutation) ResetEdge(name string) error { + switch name { + case authidentitychannel.EdgeIdentity: + m.ResetIdentity() + return nil } - return m + return fmt.Errorf("unknown AuthIdentityChannel edge %s", name) } -// withGroupID sets the ID field of the mutation. -func withGroupID(id int64) groupOption { - return func(m *GroupMutation) { - var ( - err error - once sync.Once - value *Group +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + skip_monitoring *bool + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule +} + +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) + +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) + +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withErrorPassthroughRuleID sets the ID field of the mutation. +func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule ) - m.oldValue = func(ctx context.Context) (*Group, error) { + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().Group.Get(ctx, id) + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) } }) return value, err @@ -8320,10 +8807,10 @@ func withGroupID(id int64) groupOption { } } -// withGroup sets the old Group of the mutation. -func withGroup(node *Group) groupOption { - return func(m *GroupMutation) { - m.oldValue = func(context.Context) (*Group, error) { +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { return node, nil } m.id = &node.ID @@ -8332,7 +8819,7 @@ func withGroup(node *Group) groupOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m GroupMutation) Client() *Client { +func (m ErrorPassthroughRuleMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -8340,7 +8827,7 @@ func (m GroupMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m GroupMutation) Tx() (*Tx, error) { +func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -8351,7 +8838,7 @@ func (m GroupMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *GroupMutation) ID() (id int64, exists bool) { +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -8362,7 +8849,7 @@ func (m *GroupMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -8371,19 +8858,19 @@ func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } // SetCreatedAt sets the "created_at" field. -func (m *GroupMutation) SetCreatedAt(t time.Time) { +func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { m.created_at = &t } // CreatedAt returns the value of the "created_at" field in the mutation. -func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { v := m.created_at if v == nil { return @@ -8391,10 +8878,10 @@ func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } @@ -8409,17 +8896,17 @@ func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err erro } // ResetCreatedAt resets all changes to the "created_at" field. -func (m *GroupMutation) ResetCreatedAt() { +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { m.created_at = nil } // SetUpdatedAt sets the "updated_at" field. -func (m *GroupMutation) SetUpdatedAt(t time.Time) { +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { m.updated_at = &t } // UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { v := m.updated_at if v == nil { return @@ -8427,10 +8914,10 @@ func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } @@ -8445,66 +8932,17 @@ func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err erro } // ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *GroupMutation) ResetUpdatedAt() { +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { m.updated_at = nil } -// SetDeletedAt sets the "deleted_at" field. -func (m *GroupMutation) SetDeletedAt(t time.Time) { - m.deleted_at = &t -} - -// DeletedAt returns the value of the "deleted_at" field in the mutation. -func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { - v := m.deleted_at - if v == nil { - return - } - return *v, true -} - -// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDeletedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) - } - return oldValue.DeletedAt, nil -} - -// ClearDeletedAt clears the value of the "deleted_at" field. -func (m *GroupMutation) ClearDeletedAt() { - m.deleted_at = nil - m.clearedFields[group.FieldDeletedAt] = struct{}{} -} - -// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. -func (m *GroupMutation) DeletedAtCleared() bool { - _, ok := m.clearedFields[group.FieldDeletedAt] - return ok -} - -// ResetDeletedAt resets all changes to the "deleted_at" field. -func (m *GroupMutation) ResetDeletedAt() { - m.deleted_at = nil - delete(m.clearedFields, group.FieldDeletedAt) -} - // SetName sets the "name" field. -func (m *GroupMutation) SetName(s string) { +func (m *ErrorPassthroughRuleMutation) SetName(s string) { m.name = &s } // Name returns the value of the "name" field in the mutation. -func (m *GroupMutation) Name() (r string, exists bool) { +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { v := m.name if v == nil { return @@ -8512,10 +8950,10 @@ func (m *GroupMutation) Name() (r string, exists bool) { return *v, true } -// OldName returns the old "name" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldName is only allowed on UpdateOne operations") } @@ -8530,3285 +8968,3061 @@ func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { } // ResetName resets all changes to the "name" field. -func (m *GroupMutation) ResetName() { +func (m *ErrorPassthroughRuleMutation) ResetName() { m.name = nil } -// SetDescription sets the "description" field. -func (m *GroupMutation) SetDescription(s string) { - m.description = &s +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b } -// Description returns the value of the "description" field in the mutation. -func (m *GroupMutation) Description() (r string, exists bool) { - v := m.description +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled if v == nil { return } return *v, true } -// OldDescription returns the old "description" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + return v, errors.New("OldEnabled requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) } - return oldValue.Description, nil -} - -// ClearDescription clears the value of the "description" field. -func (m *GroupMutation) ClearDescription() { - m.description = nil - m.clearedFields[group.FieldDescription] = struct{}{} -} - -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *GroupMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[group.FieldDescription] - return ok + return oldValue.Enabled, nil } -// ResetDescription resets all changes to the "description" field. -func (m *GroupMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, group.FieldDescription) +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil } -// SetRateMultiplier sets the "rate_multiplier" field. -func (m *GroupMutation) SetRateMultiplier(f float64) { - m.rate_multiplier = &f - m.addrate_multiplier = nil +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil } -// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. -func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { - v := m.rate_multiplier +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority if v == nil { return } return *v, true } -// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") + return v, errors.New("OldPriority is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRateMultiplier requires an ID field in the mutation") + return v, errors.New("OldPriority requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) + return v, fmt.Errorf("querying old value for OldPriority: %w", err) } - return oldValue.RateMultiplier, nil + return oldValue.Priority, nil } -// AddRateMultiplier adds f to the "rate_multiplier" field. -func (m *GroupMutation) AddRateMultiplier(f float64) { - if m.addrate_multiplier != nil { - *m.addrate_multiplier += f +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i } else { - m.addrate_multiplier = &f + m.addpriority = &i } } -// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. -func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { - v := m.addrate_multiplier +// AddedPriority returns the value that was added to the "priority" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority if v == nil { return } return *v, true } -// ResetRateMultiplier resets all changes to the "rate_multiplier" field. -func (m *GroupMutation) ResetRateMultiplier() { - m.rate_multiplier = nil - m.addrate_multiplier = nil +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil } -// SetIsExclusive sets the "is_exclusive" field. -func (m *GroupMutation) SetIsExclusive(b bool) { - m.is_exclusive = &b +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil } -// IsExclusive returns the value of the "is_exclusive" field in the mutation. -func (m *GroupMutation) IsExclusive() (r bool, exists bool) { - v := m.is_exclusive +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes if v == nil { return } return *v, true } -// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIsExclusive requires an ID field in the mutation") + return v, errors.New("OldErrorCodes requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) } - return oldValue.IsExclusive, nil + return oldValue.ErrorCodes, nil } -// ResetIsExclusive resets all changes to the "is_exclusive" field. -func (m *GroupMutation) ResetIsExclusive() { - m.is_exclusive = nil +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) } -// SetStatus sets the "status" field. -func (m *GroupMutation) SetStatus(s string) { - m.status = &s +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true } -// Status returns the value of the "status" field in the mutation. -func (m *GroupMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return - } - return *v, true +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} } -// OldStatus returns the old "status" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) - } - return oldValue.Status, nil +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok } -// ResetStatus resets all changes to the "status" field. -func (m *GroupMutation) ResetStatus() { - m.status = nil +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) } -// SetPlatform sets the "platform" field. -func (m *GroupMutation) SetPlatform(s string) { - m.platform = &s +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil } -// Platform returns the value of the "platform" field in the mutation. -func (m *GroupMutation) Platform() (r string, exists bool) { - v := m.platform +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords if v == nil { return } return *v, true } -// OldPlatform returns the old "platform" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlatform is only allowed on UpdateOne operations") + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlatform requires an ID field in the mutation") + return v, errors.New("OldKeywords requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) } - return oldValue.Platform, nil + return oldValue.Keywords, nil } -// ResetPlatform resets all changes to the "platform" field. -func (m *GroupMutation) ResetPlatform() { - m.platform = nil +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) } -// SetSubscriptionType sets the "subscription_type" field. -func (m *GroupMutation) SetSubscriptionType(s string) { - m.subscription_type = &s +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true } -// SubscriptionType returns the value of the "subscription_type" field in the mutation. -func (m *GroupMutation) SubscriptionType() (r string, exists bool) { - v := m.subscription_type +// ClearKeywords clears the value of the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] + return ok +} + +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +} + +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s +} + +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode if v == nil { return } return *v, true } -// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) { +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations") + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionType requires an ID field in the mutation") + return v, errors.New("OldMatchMode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err) + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) } - return oldValue.SubscriptionType, nil + return oldValue.MatchMode, nil } -// ResetSubscriptionType resets all changes to the "subscription_type" field. -func (m *GroupMutation) ResetSubscriptionType() { - m.subscription_type = nil +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil } -// SetDailyLimitUsd sets the "daily_limit_usd" field. -func (m *GroupMutation) SetDailyLimitUsd(f float64) { - m.daily_limit_usd = &f - m.adddaily_limit_usd = nil +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil } -// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. -func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) { - v := m.daily_limit_usd +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms if v == nil { return } return *v, true } -// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldPlatforms requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) } - return oldValue.DailyLimitUsd, nil + return oldValue.Platforms, nil } -// AddDailyLimitUsd adds f to the "daily_limit_usd" field. -func (m *GroupMutation) AddDailyLimitUsd(f float64) { - if m.adddaily_limit_usd != nil { - *m.adddaily_limit_usd += f - } else { - m.adddaily_limit_usd = &f - } +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) } -// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. -func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) { - v := m.adddaily_limit_usd - if v == nil { - return +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false } - return *v, true + return m.appendplatforms, true } -// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. -func (m *GroupMutation) ClearDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - m.clearedFields[group.FieldDailyLimitUsd] = struct{}{} +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} } -// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) DailyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldDailyLimitUsd] +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] return ok } -// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. -func (m *GroupMutation) ResetDailyLimitUsd() { - m.daily_limit_usd = nil - m.adddaily_limit_usd = nil - delete(m.clearedFields, group.FieldDailyLimitUsd) +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) } -// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. -func (m *GroupMutation) SetWeeklyLimitUsd(f float64) { - m.weekly_limit_usd = &f - m.addweekly_limit_usd = nil +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b } -// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. -func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) { - v := m.weekly_limit_usd +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code if v == nil { return } return *v, true } -// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) - } - return oldValue.WeeklyLimitUsd, nil -} - -// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. -func (m *GroupMutation) AddWeeklyLimitUsd(f float64) { - if m.addweekly_limit_usd != nil { - *m.addweekly_limit_usd += f - } else { - m.addweekly_limit_usd = &f - } -} - -// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { - v := m.addweekly_limit_usd - if v == nil { - return + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) } - return *v, true -} - -// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. -func (m *GroupMutation) ClearWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{} -} - -// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) WeeklyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldWeeklyLimitUsd] - return ok + return oldValue.PassthroughCode, nil } -// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. -func (m *GroupMutation) ResetWeeklyLimitUsd() { - m.weekly_limit_usd = nil - m.addweekly_limit_usd = nil - delete(m.clearedFields, group.FieldWeeklyLimitUsd) +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil } -// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. -func (m *GroupMutation) SetMonthlyLimitUsd(f float64) { - m.monthly_limit_usd = &f - m.addmonthly_limit_usd = nil +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil } -// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. -func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) { - v := m.monthly_limit_usd +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code if v == nil { return } return *v, true } -// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") + return v, errors.New("OldResponseCode requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) } - return oldValue.MonthlyLimitUsd, nil + return oldValue.ResponseCode, nil } -// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. -func (m *GroupMutation) AddMonthlyLimitUsd(f float64) { - if m.addmonthly_limit_usd != nil { - *m.addmonthly_limit_usd += f +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i } else { - m.addmonthly_limit_usd = &f + m.addresponse_code = &i } } -// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. -func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { - v := m.addmonthly_limit_usd +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code if v == nil { return } return *v, true } -// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. -func (m *GroupMutation) ClearMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{} +// ClearResponseCode clears the value of the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} } -// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. -func (m *GroupMutation) MonthlyLimitUsdCleared() bool { - _, ok := m.clearedFields[group.FieldMonthlyLimitUsd] +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] return ok } -// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. -func (m *GroupMutation) ResetMonthlyLimitUsd() { - m.monthly_limit_usd = nil - m.addmonthly_limit_usd = nil - delete(m.clearedFields, group.FieldMonthlyLimitUsd) +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) } -// SetDefaultValidityDays sets the "default_validity_days" field. -func (m *GroupMutation) SetDefaultValidityDays(i int) { - m.default_validity_days = &i - m.adddefault_validity_days = nil +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b } -// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. -func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { - v := m.default_validity_days +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body if v == nil { return } return *v, true } -// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) - } - return oldValue.DefaultValidityDays, nil -} - -// AddDefaultValidityDays adds i to the "default_validity_days" field. -func (m *GroupMutation) AddDefaultValidityDays(i int) { - if m.adddefault_validity_days != nil { - *m.adddefault_validity_days += i - } else { - m.adddefault_validity_days = &i - } -} - -// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. -func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { - v := m.adddefault_validity_days - if v == nil { - return + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) } - return *v, true + return oldValue.PassthroughBody, nil } -// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. -func (m *GroupMutation) ResetDefaultValidityDays() { - m.default_validity_days = nil - m.adddefault_validity_days = nil +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil } -// SetImagePrice1k sets the "image_price_1k" field. -func (m *GroupMutation) SetImagePrice1k(f float64) { - m.image_price_1k = &f - m.addimage_price_1k = nil +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s } -// ImagePrice1k returns the value of the "image_price_1k" field in the mutation. -func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) { - v := m.image_price_1k +// CustomMessage returns the value of the "custom_message" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message if v == nil { return } return *v, true } -// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations") + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice1k requires an ID field in the mutation") + return v, errors.New("OldCustomMessage requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err) + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) } - return oldValue.ImagePrice1k, nil + return oldValue.CustomMessage, nil } -// AddImagePrice1k adds f to the "image_price_1k" field. -func (m *GroupMutation) AddImagePrice1k(f float64) { - if m.addimage_price_1k != nil { - *m.addimage_price_1k += f - } else { - m.addimage_price_1k = &f - } +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} } -// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation. -func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) { - v := m.addimage_price_1k - if v == nil { - return - } - return *v, true +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] + return ok } -// ClearImagePrice1k clears the value of the "image_price_1k" field. -func (m *GroupMutation) ClearImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - m.clearedFields[group.FieldImagePrice1k] = struct{}{} +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) } -// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice1kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice1k] - return ok +// SetSkipMonitoring sets the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) SetSkipMonitoring(b bool) { + m.skip_monitoring = &b } -// ResetImagePrice1k resets all changes to the "image_price_1k" field. -func (m *GroupMutation) ResetImagePrice1k() { - m.image_price_1k = nil - m.addimage_price_1k = nil - delete(m.clearedFields, group.FieldImagePrice1k) -} - -// SetImagePrice2k sets the "image_price_2k" field. -func (m *GroupMutation) SetImagePrice2k(f float64) { - m.image_price_2k = &f - m.addimage_price_2k = nil -} - -// ImagePrice2k returns the value of the "image_price_2k" field in the mutation. -func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) { - v := m.image_price_2k +// SkipMonitoring returns the value of the "skip_monitoring" field in the mutation. +func (m *ErrorPassthroughRuleMutation) SkipMonitoring() (r bool, exists bool) { + v := m.skip_monitoring if v == nil { return } return *v, true } -// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldSkipMonitoring returns the old "skip_monitoring" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldSkipMonitoring(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations") + return v, errors.New("OldSkipMonitoring is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice2k requires an ID field in the mutation") + return v, errors.New("OldSkipMonitoring requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err) - } - return oldValue.ImagePrice2k, nil -} - -// AddImagePrice2k adds f to the "image_price_2k" field. -func (m *GroupMutation) AddImagePrice2k(f float64) { - if m.addimage_price_2k != nil { - *m.addimage_price_2k += f - } else { - m.addimage_price_2k = &f - } -} - -// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation. -func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) { - v := m.addimage_price_2k - if v == nil { - return + return v, fmt.Errorf("querying old value for OldSkipMonitoring: %w", err) } - return *v, true -} - -// ClearImagePrice2k clears the value of the "image_price_2k" field. -func (m *GroupMutation) ClearImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - m.clearedFields[group.FieldImagePrice2k] = struct{}{} -} - -// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice2kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice2k] - return ok + return oldValue.SkipMonitoring, nil } -// ResetImagePrice2k resets all changes to the "image_price_2k" field. -func (m *GroupMutation) ResetImagePrice2k() { - m.image_price_2k = nil - m.addimage_price_2k = nil - delete(m.clearedFields, group.FieldImagePrice2k) +// ResetSkipMonitoring resets all changes to the "skip_monitoring" field. +func (m *ErrorPassthroughRuleMutation) ResetSkipMonitoring() { + m.skip_monitoring = nil } -// SetImagePrice4k sets the "image_price_4k" field. -func (m *GroupMutation) SetImagePrice4k(f float64) { - m.image_price_4k = &f - m.addimage_price_4k = nil +// SetDescription sets the "description" field. +func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s } -// ImagePrice4k returns the value of the "image_price_4k" field in the mutation. -func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) { - v := m.image_price_4k +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description if v == nil { return } return *v, true } -// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) { +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations") + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldImagePrice4k requires an ID field in the mutation") + return v, errors.New("OldDescription requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err) - } - return oldValue.ImagePrice4k, nil -} - -// AddImagePrice4k adds f to the "image_price_4k" field. -func (m *GroupMutation) AddImagePrice4k(f float64) { - if m.addimage_price_4k != nil { - *m.addimage_price_4k += f - } else { - m.addimage_price_4k = &f - } -} - -// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation. -func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) { - v := m.addimage_price_4k - if v == nil { - return + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return *v, true + return oldValue.Description, nil } -// ClearImagePrice4k clears the value of the "image_price_4k" field. -func (m *GroupMutation) ClearImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - m.clearedFields[group.FieldImagePrice4k] = struct{}{} +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} } -// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation. -func (m *GroupMutation) ImagePrice4kCleared() bool { - _, ok := m.clearedFields[group.FieldImagePrice4k] +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] return ok } -// ResetImagePrice4k resets all changes to the "image_price_4k" field. -func (m *GroupMutation) ResetImagePrice4k() { - m.image_price_4k = nil - m.addimage_price_4k = nil - delete(m.clearedFields, group.FieldImagePrice4k) +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) } -// SetClaudeCodeOnly sets the "claude_code_only" field. -func (m *GroupMutation) SetClaudeCodeOnly(b bool) { - m.claude_code_only = &b +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) } -// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. -func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { - v := m.claude_code_only - if v == nil { - return +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] } - return *v, true + m.Where(p...) } -// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) - } - return oldValue.ClaudeCodeOnly, nil +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op } -// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. -func (m *GroupMutation) ResetClaudeCodeOnly() { - m.claude_code_only = nil +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op } -// SetFallbackGroupID sets the "fallback_group_id" field. -func (m *GroupMutation) SetFallbackGroupID(i int64) { - m.fallback_group_id = &i - m.addfallback_group_id = nil +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ } -// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. -func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { - v := m.fallback_group_id - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 15) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) } - return *v, true -} - -// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) } - return oldValue.FallbackGroupID, nil -} - -// AddFallbackGroupID adds i to the "fallback_group_id" field. -func (m *GroupMutation) AddFallbackGroupID(i int64) { - if m.addfallback_group_id != nil { - *m.addfallback_group_id += i - } else { - m.addfallback_group_id = &i + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) } -} - -// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { - v := m.addfallback_group_id - if v == nil { - return + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) } - return *v, true -} - -// ClearFallbackGroupID clears the value of the "fallback_group_id" field. -func (m *GroupMutation) ClearFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - m.clearedFields[group.FieldFallbackGroupID] = struct{}{} -} - -// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupID] - return ok -} - -// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. -func (m *GroupMutation) ResetFallbackGroupID() { - m.fallback_group_id = nil - m.addfallback_group_id = nil - delete(m.clearedFields, group.FieldFallbackGroupID) -} - -// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { - m.fallback_group_id_on_invalid_request = &i - m.addfallback_group_id_on_invalid_request = nil -} - -// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.fallback_group_id_on_invalid_request - if v == nil { - return - } - return *v, true -} - -// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) } - return oldValue.FallbackGroupIDOnInvalidRequest, nil -} - -// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { - if m.addfallback_group_id_on_invalid_request != nil { - *m.addfallback_group_id_on_invalid_request += i - } else { - m.addfallback_group_id_on_invalid_request = &i + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) } -} - -// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. -func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { - v := m.addfallback_group_id_on_invalid_request - if v == nil { - return + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) } - return *v, true -} - -// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} -} - -// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. -func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { - _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] - return ok -} - -// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. -func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { - m.fallback_group_id_on_invalid_request = nil - m.addfallback_group_id_on_invalid_request = nil - delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) -} - -// SetModelRouting sets the "model_routing" field. -func (m *GroupMutation) SetModelRouting(value map[string][]int64) { - m.model_routing = &value -} - -// ModelRouting returns the value of the "model_routing" field in the mutation. -func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { - v := m.model_routing - if v == nil { - return + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) } - return *v, true -} - -// OldModelRouting returns the old "model_routing" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRouting requires an ID field in the mutation") + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) + if m.skip_monitoring != nil { + fields = append(fields, errorpassthroughrule.FieldSkipMonitoring) } - return oldValue.ModelRouting, nil -} - -// ClearModelRouting clears the value of the "model_routing" field. -func (m *GroupMutation) ClearModelRouting() { - m.model_routing = nil - m.clearedFields[group.FieldModelRouting] = struct{}{} -} - -// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. -func (m *GroupMutation) ModelRoutingCleared() bool { - _, ok := m.clearedFields[group.FieldModelRouting] - return ok -} - -// ResetModelRouting resets all changes to the "model_routing" field. -func (m *GroupMutation) ResetModelRouting() { - m.model_routing = nil - delete(m.clearedFields, group.FieldModelRouting) -} - -// SetModelRoutingEnabled sets the "model_routing_enabled" field. -func (m *GroupMutation) SetModelRoutingEnabled(b bool) { - m.model_routing_enabled = &b -} - -// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. -func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { - v := m.model_routing_enabled - if v == nil { - return + if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) } - return *v, true + return fields } -// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldSkipMonitoring: + return m.SkipMonitoring() + case errorpassthroughrule.FieldDescription: + return m.Description() } - return oldValue.ModelRoutingEnabled, nil -} - -// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. -func (m *GroupMutation) ResetModelRoutingEnabled() { - m.model_routing_enabled = nil -} - -// SetMcpXMLInject sets the "mcp_xml_inject" field. -func (m *GroupMutation) SetMcpXMLInject(b bool) { - m.mcp_xml_inject = &b + return nil, false } -// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. -func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { - v := m.mcp_xml_inject - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldSkipMonitoring: + return m.OldSkipMonitoring(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) } - return *v, true + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldSkipMonitoring: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSkipMonitoring(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil } - return oldValue.McpXMLInject, nil + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. -func (m *GroupMutation) ResetMcpXMLInject() { - m.mcp_xml_inject = nil +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields } -// SetSupportedModelScopes sets the "supported_model_scopes" field. -func (m *GroupMutation) SetSupportedModelScopes(s []string) { - m.supported_model_scopes = &s - m.appendsupported_model_scopes = nil +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false } -// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. -func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { - v := m.supported_model_scopes - if v == nil { - return +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil } - return *v, true + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) } -// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) } - return oldValue.SupportedModelScopes, nil -} - -// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. -func (m *GroupMutation) AppendSupportedModelScopes(s []string) { - m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) -} - -// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. -func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { - if len(m.appendsupported_model_scopes) == 0 { - return nil, false + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) } - return m.appendsupported_model_scopes, true + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields } -// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. -func (m *GroupMutation) ResetSupportedModelScopes() { - m.supported_model_scopes = nil - m.appendsupported_model_scopes = nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// SetSortOrder sets the "sort_order" field. -func (m *GroupMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) } -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *GroupMutation) SortOrder() (r int, exists bool) { - v := m.sort_order - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return nil + case errorpassthroughrule.FieldSkipMonitoring: + m.ResetSkipMonitoring() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil } - return *v, true + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) } -// OldSortOrder returns the old "sort_order" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) - } - return oldValue.SortOrder, nil +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// AddSortOrder adds i to the "sort_order" field. -func (m *GroupMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i - } +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil } -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. -func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order - if v == nil { - return - } - return *v, true +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *GroupMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil } -// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. -func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { - m.allow_messages_dispatch = &b +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. -func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { - v := m.allow_messages_dispatch - if v == nil { - return - } - return *v, true +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false } -// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) - } - return oldValue.AllowMessagesDispatch, nil +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) } -// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. -func (m *GroupMutation) ResetAllowMessagesDispatch() { - m.allow_messages_dispatch = nil +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) } -// SetRequireOauthOnly sets the "require_oauth_only" field. -func (m *GroupMutation) SetRequireOauthOnly(b bool) { - m.require_oauth_only = &b +// GroupMutation represents an operation that mutates the Group nodes in the graph. +type GroupMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + sort_order *int + addsort_order *int + allow_messages_dispatch *bool + require_oauth_only *bool + require_privacy_set *bool + default_mapped_model *string + messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } -// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. -func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { - v := m.require_oauth_only - if v == nil { - return - } - return *v, true -} +var _ ent.Mutation = (*GroupMutation)(nil) -// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. -// If the Group object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") +// groupOption allows management of the mutation configuration using functional options. +type groupOption func(*GroupMutation) + +// newGroupMutation creates new mutation for the Group entity. +func newGroupMutation(c config, op Op, opts ...groupOption) *GroupMutation { + m := &GroupMutation{ + config: c, + op: op, + typ: TypeGroup, + clearedFields: make(map[string]struct{}), } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) + for _, opt := range opts { + opt(m) } - return oldValue.RequireOauthOnly, nil + return m } -// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. -func (m *GroupMutation) ResetRequireOauthOnly() { - m.require_oauth_only = nil +// withGroupID sets the ID field of the mutation. +func withGroupID(id int64) groupOption { + return func(m *GroupMutation) { + var ( + err error + once sync.Once + value *Group + ) + m.oldValue = func(ctx context.Context) (*Group, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Group.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } } -// SetRequirePrivacySet sets the "require_privacy_set" field. -func (m *GroupMutation) SetRequirePrivacySet(b bool) { - m.require_privacy_set = &b +// withGroup sets the old Group of the mutation. +func withGroup(node *Group) groupOption { + return func(m *GroupMutation) { + m.oldValue = func(context.Context) (*Group, error) { + return node, nil + } + m.id = &node.ID + } } -// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. -func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { - v := m.require_privacy_set +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m GroupMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m GroupMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *GroupMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *GroupMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Group.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *GroupMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *GroupMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// OldCreatedAt returns the old "created_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { +func (m *GroupMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.RequirePrivacySet, nil + return oldValue.CreatedAt, nil } -// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. -func (m *GroupMutation) ResetRequirePrivacySet() { - m.require_privacy_set = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *GroupMutation) ResetCreatedAt() { + m.created_at = nil } -// SetDefaultMappedModel sets the "default_mapped_model" field. -func (m *GroupMutation) SetDefaultMappedModel(s string) { - m.default_mapped_model = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *GroupMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. -func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { - v := m.default_mapped_model +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *GroupMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// OldUpdatedAt returns the old "updated_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.DefaultMappedModel, nil + return oldValue.UpdatedAt, nil } -// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. -func (m *GroupMutation) ResetDefaultMappedModel() { - m.default_mapped_model = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *GroupMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. -func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { - m.messages_dispatch_model_config = &damdmc +// SetDeletedAt sets the "deleted_at" field. +func (m *GroupMutation) SetDeletedAt(t time.Time) { + m.deleted_at = &t } -// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. -func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { - v := m.messages_dispatch_model_config +// DeletedAt returns the value of the "deleted_at" field in the mutation. +func (m *GroupMutation) DeletedAt() (r time.Time, exists bool) { + v := m.deleted_at if v == nil { return } return *v, true } -// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// OldDeletedAt returns the old "deleted_at" field's value of the Group entity. // If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { +func (m *GroupMutation) OldDeletedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") + return v, errors.New("OldDeletedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") + return v, errors.New("OldDeletedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) + return v, fmt.Errorf("querying old value for OldDeletedAt: %w", err) } - return oldValue.MessagesDispatchModelConfig, nil + return oldValue.DeletedAt, nil } -// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. -func (m *GroupMutation) ResetMessagesDispatchModelConfig() { - m.messages_dispatch_model_config = nil +// ClearDeletedAt clears the value of the "deleted_at" field. +func (m *GroupMutation) ClearDeletedAt() { + m.deleted_at = nil + m.clearedFields[group.FieldDeletedAt] = struct{}{} } -// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. -func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { - if m.api_keys == nil { - m.api_keys = make(map[int64]struct{}) - } - for i := range ids { - m.api_keys[ids[i]] = struct{}{} - } +// DeletedAtCleared returns if the "deleted_at" field was cleared in this mutation. +func (m *GroupMutation) DeletedAtCleared() bool { + _, ok := m.clearedFields[group.FieldDeletedAt] + return ok } -// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) ClearAPIKeys() { - m.clearedapi_keys = true +// ResetDeletedAt resets all changes to the "deleted_at" field. +func (m *GroupMutation) ResetDeletedAt() { + m.deleted_at = nil + delete(m.clearedFields, group.FieldDeletedAt) } -// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. -func (m *GroupMutation) APIKeysCleared() bool { - return m.clearedapi_keys +// SetName sets the "name" field. +func (m *GroupMutation) SetName(s string) { + m.name = &s } -// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. -func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { - if m.removedapi_keys == nil { - m.removedapi_keys = make(map[int64]struct{}) - } - for i := range ids { - delete(m.api_keys, ids[i]) - m.removedapi_keys[ids[i]] = struct{}{} +// Name returns the value of the "name" field in the mutation. +func (m *GroupMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return } + return *v, true } -// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. -func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { - for id := range m.removedapi_keys { - ids = append(ids, id) +// OldName returns the old "name" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") } - return -} - -// APIKeysIDs returns the "api_keys" edge IDs in the mutation. -func (m *GroupMutation) APIKeysIDs() (ids []int64) { - for id := range m.api_keys { - ids = append(ids, id) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") } - return + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil } -// ResetAPIKeys resets all changes to the "api_keys" edge. -func (m *GroupMutation) ResetAPIKeys() { - m.api_keys = nil - m.clearedapi_keys = false - m.removedapi_keys = nil +// ResetName resets all changes to the "name" field. +func (m *GroupMutation) ResetName() { + m.name = nil } -// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. -func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { - if m.redeem_codes == nil { - m.redeem_codes = make(map[int64]struct{}) - } - for i := range ids { - m.redeem_codes[ids[i]] = struct{}{} - } +// SetDescription sets the "description" field. +func (m *GroupMutation) SetDescription(s string) { + m.description = &s } -// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) ClearRedeemCodes() { - m.clearedredeem_codes = true -} - -// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. -func (m *GroupMutation) RedeemCodesCleared() bool { - return m.clearedredeem_codes +// Description returns the value of the "description" field in the mutation. +func (m *GroupMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true } -// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. -func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { - if m.removedredeem_codes == nil { - m.removedredeem_codes = make(map[int64]struct{}) +// OldDescription returns the old "description" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.redeem_codes, ids[i]) - m.removedredeem_codes[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") } -} - -// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. -func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { - for id := range m.removedredeem_codes { - ids = append(ids, id) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) } - return + return oldValue.Description, nil } -// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. -func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { - for id := range m.redeem_codes { - ids = append(ids, id) - } - return +// ClearDescription clears the value of the "description" field. +func (m *GroupMutation) ClearDescription() { + m.description = nil + m.clearedFields[group.FieldDescription] = struct{}{} } -// ResetRedeemCodes resets all changes to the "redeem_codes" edge. -func (m *GroupMutation) ResetRedeemCodes() { - m.redeem_codes = nil - m.clearedredeem_codes = false - m.removedredeem_codes = nil +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *GroupMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[group.FieldDescription] + return ok } -// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. -func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { - if m.subscriptions == nil { - m.subscriptions = make(map[int64]struct{}) - } - for i := range ids { - m.subscriptions[ids[i]] = struct{}{} - } +// ResetDescription resets all changes to the "description" field. +func (m *GroupMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, group.FieldDescription) } -// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) ClearSubscriptions() { - m.clearedsubscriptions = true +// SetRateMultiplier sets the "rate_multiplier" field. +func (m *GroupMutation) SetRateMultiplier(f float64) { + m.rate_multiplier = &f + m.addrate_multiplier = nil } -// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. -func (m *GroupMutation) SubscriptionsCleared() bool { - return m.clearedsubscriptions +// RateMultiplier returns the value of the "rate_multiplier" field in the mutation. +func (m *GroupMutation) RateMultiplier() (r float64, exists bool) { + v := m.rate_multiplier + if v == nil { + return + } + return *v, true } -// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. -func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { - if m.removedsubscriptions == nil { - m.removedsubscriptions = make(map[int64]struct{}) +// OldRateMultiplier returns the old "rate_multiplier" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.subscriptions, ids[i]) - m.removedsubscriptions[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRateMultiplier requires an ID field in the mutation") } -} - -// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. -func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { - for id := range m.removedsubscriptions { - ids = append(ids, id) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err) } - return + return oldValue.RateMultiplier, nil } -// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. -func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { - for id := range m.subscriptions { - ids = append(ids, id) +// AddRateMultiplier adds f to the "rate_multiplier" field. +func (m *GroupMutation) AddRateMultiplier(f float64) { + if m.addrate_multiplier != nil { + *m.addrate_multiplier += f + } else { + m.addrate_multiplier = &f } - return } -// ResetSubscriptions resets all changes to the "subscriptions" edge. -func (m *GroupMutation) ResetSubscriptions() { - m.subscriptions = nil - m.clearedsubscriptions = false - m.removedsubscriptions = nil +// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation. +func (m *GroupMutation) AddedRateMultiplier() (r float64, exists bool) { + v := m.addrate_multiplier + if v == nil { + return + } + return *v, true } -// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. -func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { - if m.usage_logs == nil { - m.usage_logs = make(map[int64]struct{}) - } - for i := range ids { - m.usage_logs[ids[i]] = struct{}{} - } +// ResetRateMultiplier resets all changes to the "rate_multiplier" field. +func (m *GroupMutation) ResetRateMultiplier() { + m.rate_multiplier = nil + m.addrate_multiplier = nil } -// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) ClearUsageLogs() { - m.clearedusage_logs = true +// SetIsExclusive sets the "is_exclusive" field. +func (m *GroupMutation) SetIsExclusive(b bool) { + m.is_exclusive = &b } -// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. -func (m *GroupMutation) UsageLogsCleared() bool { - return m.clearedusage_logs +// IsExclusive returns the value of the "is_exclusive" field in the mutation. +func (m *GroupMutation) IsExclusive() (r bool, exists bool) { + v := m.is_exclusive + if v == nil { + return + } + return *v, true } -// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. -func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { - if m.removedusage_logs == nil { - m.removedusage_logs = make(map[int64]struct{}) +// OldIsExclusive returns the old "is_exclusive" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldIsExclusive(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsExclusive is only allowed on UpdateOne operations") } - for i := range ids { - delete(m.usage_logs, ids[i]) - m.removedusage_logs[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsExclusive requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsExclusive: %w", err) } + return oldValue.IsExclusive, nil } -// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. -func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return +// ResetIsExclusive resets all changes to the "is_exclusive" field. +func (m *GroupMutation) ResetIsExclusive() { + m.is_exclusive = nil } -// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. -func (m *GroupMutation) UsageLogsIDs() (ids []int64) { - for id := range m.usage_logs { - ids = append(ids, id) - } - return +// SetStatus sets the "status" field. +func (m *GroupMutation) SetStatus(s string) { + m.status = &s } -// ResetUsageLogs resets all changes to the "usage_logs" edge. -func (m *GroupMutation) ResetUsageLogs() { - m.usage_logs = nil - m.clearedusage_logs = false - m.removedusage_logs = nil +// Status returns the value of the "status" field in the mutation. +func (m *GroupMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true } -// AddAccountIDs adds the "accounts" edge to the Account entity by ids. -func (m *GroupMutation) AddAccountIDs(ids ...int64) { - if m.accounts == nil { - m.accounts = make(map[int64]struct{}) +// OldStatus returns the old "status" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") } - for i := range ids { - m.accounts[ids[i]] = struct{}{} + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) } + return oldValue.Status, nil } -// ClearAccounts clears the "accounts" edge to the Account entity. -func (m *GroupMutation) ClearAccounts() { - m.clearedaccounts = true +// ResetStatus resets all changes to the "status" field. +func (m *GroupMutation) ResetStatus() { + m.status = nil } -// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. -func (m *GroupMutation) AccountsCleared() bool { - return m.clearedaccounts +// SetPlatform sets the "platform" field. +func (m *GroupMutation) SetPlatform(s string) { + m.platform = &s } -// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. -func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { - if m.removedaccounts == nil { - m.removedaccounts = make(map[int64]struct{}) - } - for i := range ids { - delete(m.accounts, ids[i]) - m.removedaccounts[ids[i]] = struct{}{} +// Platform returns the value of the "platform" field in the mutation. +func (m *GroupMutation) Platform() (r string, exists bool) { + v := m.platform + if v == nil { + return } + return *v, true } -// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. -func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { - for id := range m.removedaccounts { - ids = append(ids, id) +// OldPlatform returns the old "platform" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldPlatform(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatform is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatform requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatform: %w", err) + } + return oldValue.Platform, nil } -// AccountsIDs returns the "accounts" edge IDs in the mutation. -func (m *GroupMutation) AccountsIDs() (ids []int64) { - for id := range m.accounts { - ids = append(ids, id) - } - return +// ResetPlatform resets all changes to the "platform" field. +func (m *GroupMutation) ResetPlatform() { + m.platform = nil } -// ResetAccounts resets all changes to the "accounts" edge. -func (m *GroupMutation) ResetAccounts() { - m.accounts = nil - m.clearedaccounts = false - m.removedaccounts = nil +// SetSubscriptionType sets the "subscription_type" field. +func (m *GroupMutation) SetSubscriptionType(s string) { + m.subscription_type = &s } -// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. -func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { - if m.allowed_users == nil { - m.allowed_users = make(map[int64]struct{}) +// SubscriptionType returns the value of the "subscription_type" field in the mutation. +func (m *GroupMutation) SubscriptionType() (r string, exists bool) { + v := m.subscription_type + if v == nil { + return } - for i := range ids { - m.allowed_users[ids[i]] = struct{}{} + return *v, true +} + +// OldSubscriptionType returns the old "subscription_type" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSubscriptionType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionType: %w", err) } + return oldValue.SubscriptionType, nil } -// ClearAllowedUsers clears the "allowed_users" edge to the User entity. -func (m *GroupMutation) ClearAllowedUsers() { - m.clearedallowed_users = true +// ResetSubscriptionType resets all changes to the "subscription_type" field. +func (m *GroupMutation) ResetSubscriptionType() { + m.subscription_type = nil } -// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. -func (m *GroupMutation) AllowedUsersCleared() bool { - return m.clearedallowed_users +// SetDailyLimitUsd sets the "daily_limit_usd" field. +func (m *GroupMutation) SetDailyLimitUsd(f float64) { + m.daily_limit_usd = &f + m.adddaily_limit_usd = nil } -// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. -func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { - if m.removedallowed_users == nil { - m.removedallowed_users = make(map[int64]struct{}) - } - for i := range ids { - delete(m.allowed_users, ids[i]) - m.removedallowed_users[ids[i]] = struct{}{} +// DailyLimitUsd returns the value of the "daily_limit_usd" field in the mutation. +func (m *GroupMutation) DailyLimitUsd() (r float64, exists bool) { + v := m.daily_limit_usd + if v == nil { + return } + return *v, true } -// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. -func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { - for id := range m.removedallowed_users { - ids = append(ids, id) +// OldDailyLimitUsd returns the old "daily_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDailyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDailyLimitUsd is only allowed on UpdateOne operations") } - return + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDailyLimitUsd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDailyLimitUsd: %w", err) + } + return oldValue.DailyLimitUsd, nil } -// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. -func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { - for id := range m.allowed_users { - ids = append(ids, id) +// AddDailyLimitUsd adds f to the "daily_limit_usd" field. +func (m *GroupMutation) AddDailyLimitUsd(f float64) { + if m.adddaily_limit_usd != nil { + *m.adddaily_limit_usd += f + } else { + m.adddaily_limit_usd = &f } - return } -// ResetAllowedUsers resets all changes to the "allowed_users" edge. -func (m *GroupMutation) ResetAllowedUsers() { - m.allowed_users = nil - m.clearedallowed_users = false - m.removedallowed_users = nil +// AddedDailyLimitUsd returns the value that was added to the "daily_limit_usd" field in this mutation. +func (m *GroupMutation) AddedDailyLimitUsd() (r float64, exists bool) { + v := m.adddaily_limit_usd + if v == nil { + return + } + return *v, true } -// Where appends a list predicates to the GroupMutation builder. -func (m *GroupMutation) Where(ps ...predicate.Group) { - m.predicates = append(m.predicates, ps...) +// ClearDailyLimitUsd clears the value of the "daily_limit_usd" field. +func (m *GroupMutation) ClearDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + m.clearedFields[group.FieldDailyLimitUsd] = struct{}{} } -// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.Group, len(ps)) - for i := range ps { - p[i] = ps[i] - } - m.Where(p...) +// DailyLimitUsdCleared returns if the "daily_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) DailyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldDailyLimitUsd] + return ok } -// Op returns the operation name. -func (m *GroupMutation) Op() Op { - return m.op +// ResetDailyLimitUsd resets all changes to the "daily_limit_usd" field. +func (m *GroupMutation) ResetDailyLimitUsd() { + m.daily_limit_usd = nil + m.adddaily_limit_usd = nil + delete(m.clearedFields, group.FieldDailyLimitUsd) } -// SetOp allows setting the mutation operation. -func (m *GroupMutation) SetOp(op Op) { - m.op = op +// SetWeeklyLimitUsd sets the "weekly_limit_usd" field. +func (m *GroupMutation) SetWeeklyLimitUsd(f float64) { + m.weekly_limit_usd = &f + m.addweekly_limit_usd = nil } -// Type returns the node type of this mutation (Group). -func (m *GroupMutation) Type() string { - return m.typ +// WeeklyLimitUsd returns the value of the "weekly_limit_usd" field in the mutation. +func (m *GroupMutation) WeeklyLimitUsd() (r float64, exists bool) { + v := m.weekly_limit_usd + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) - if m.created_at != nil { - fields = append(fields, group.FieldCreatedAt) - } - if m.updated_at != nil { - fields = append(fields, group.FieldUpdatedAt) - } - if m.deleted_at != nil { - fields = append(fields, group.FieldDeletedAt) - } - if m.name != nil { - fields = append(fields, group.FieldName) - } - if m.description != nil { - fields = append(fields, group.FieldDescription) - } - if m.rate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) - } - if m.is_exclusive != nil { - fields = append(fields, group.FieldIsExclusive) - } - if m.status != nil { - fields = append(fields, group.FieldStatus) - } - if m.platform != nil { - fields = append(fields, group.FieldPlatform) - } - if m.subscription_type != nil { - fields = append(fields, group.FieldSubscriptionType) - } - if m.daily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) - } - if m.weekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) +// OldWeeklyLimitUsd returns the old "weekly_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldWeeklyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldWeeklyLimitUsd is only allowed on UpdateOne operations") } - if m.monthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldWeeklyLimitUsd requires an ID field in the mutation") } - if m.default_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldWeeklyLimitUsd: %w", err) } - if m.image_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) + return oldValue.WeeklyLimitUsd, nil +} + +// AddWeeklyLimitUsd adds f to the "weekly_limit_usd" field. +func (m *GroupMutation) AddWeeklyLimitUsd(f float64) { + if m.addweekly_limit_usd != nil { + *m.addweekly_limit_usd += f + } else { + m.addweekly_limit_usd = &f } - if m.image_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) +} + +// AddedWeeklyLimitUsd returns the value that was added to the "weekly_limit_usd" field in this mutation. +func (m *GroupMutation) AddedWeeklyLimitUsd() (r float64, exists bool) { + v := m.addweekly_limit_usd + if v == nil { + return } - if m.image_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) - } - if m.claude_code_only != nil { - fields = append(fields, group.FieldClaudeCodeOnly) - } - if m.fallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) + return *v, true +} + +// ClearWeeklyLimitUsd clears the value of the "weekly_limit_usd" field. +func (m *GroupMutation) ClearWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + m.clearedFields[group.FieldWeeklyLimitUsd] = struct{}{} +} + +// WeeklyLimitUsdCleared returns if the "weekly_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) WeeklyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldWeeklyLimitUsd] + return ok +} + +// ResetWeeklyLimitUsd resets all changes to the "weekly_limit_usd" field. +func (m *GroupMutation) ResetWeeklyLimitUsd() { + m.weekly_limit_usd = nil + m.addweekly_limit_usd = nil + delete(m.clearedFields, group.FieldWeeklyLimitUsd) +} + +// SetMonthlyLimitUsd sets the "monthly_limit_usd" field. +func (m *GroupMutation) SetMonthlyLimitUsd(f float64) { + m.monthly_limit_usd = &f + m.addmonthly_limit_usd = nil +} + +// MonthlyLimitUsd returns the value of the "monthly_limit_usd" field in the mutation. +func (m *GroupMutation) MonthlyLimitUsd() (r float64, exists bool) { + v := m.monthly_limit_usd + if v == nil { + return } - if m.fallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return *v, true +} + +// OldMonthlyLimitUsd returns the old "monthly_limit_usd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMonthlyLimitUsd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMonthlyLimitUsd is only allowed on UpdateOne operations") } - if m.model_routing != nil { - fields = append(fields, group.FieldModelRouting) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMonthlyLimitUsd requires an ID field in the mutation") } - if m.model_routing_enabled != nil { - fields = append(fields, group.FieldModelRoutingEnabled) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMonthlyLimitUsd: %w", err) } - if m.mcp_xml_inject != nil { - fields = append(fields, group.FieldMcpXMLInject) + return oldValue.MonthlyLimitUsd, nil +} + +// AddMonthlyLimitUsd adds f to the "monthly_limit_usd" field. +func (m *GroupMutation) AddMonthlyLimitUsd(f float64) { + if m.addmonthly_limit_usd != nil { + *m.addmonthly_limit_usd += f + } else { + m.addmonthly_limit_usd = &f } - if m.supported_model_scopes != nil { - fields = append(fields, group.FieldSupportedModelScopes) +} + +// AddedMonthlyLimitUsd returns the value that was added to the "monthly_limit_usd" field in this mutation. +func (m *GroupMutation) AddedMonthlyLimitUsd() (r float64, exists bool) { + v := m.addmonthly_limit_usd + if v == nil { + return } - if m.sort_order != nil { - fields = append(fields, group.FieldSortOrder) + return *v, true +} + +// ClearMonthlyLimitUsd clears the value of the "monthly_limit_usd" field. +func (m *GroupMutation) ClearMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + m.clearedFields[group.FieldMonthlyLimitUsd] = struct{}{} +} + +// MonthlyLimitUsdCleared returns if the "monthly_limit_usd" field was cleared in this mutation. +func (m *GroupMutation) MonthlyLimitUsdCleared() bool { + _, ok := m.clearedFields[group.FieldMonthlyLimitUsd] + return ok +} + +// ResetMonthlyLimitUsd resets all changes to the "monthly_limit_usd" field. +func (m *GroupMutation) ResetMonthlyLimitUsd() { + m.monthly_limit_usd = nil + m.addmonthly_limit_usd = nil + delete(m.clearedFields, group.FieldMonthlyLimitUsd) +} + +// SetDefaultValidityDays sets the "default_validity_days" field. +func (m *GroupMutation) SetDefaultValidityDays(i int) { + m.default_validity_days = &i + m.adddefault_validity_days = nil +} + +// DefaultValidityDays returns the value of the "default_validity_days" field in the mutation. +func (m *GroupMutation) DefaultValidityDays() (r int, exists bool) { + v := m.default_validity_days + if v == nil { + return } - if m.allow_messages_dispatch != nil { - fields = append(fields, group.FieldAllowMessagesDispatch) + return *v, true +} + +// OldDefaultValidityDays returns the old "default_validity_days" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldDefaultValidityDays(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDefaultValidityDays is only allowed on UpdateOne operations") } - if m.require_oauth_only != nil { - fields = append(fields, group.FieldRequireOauthOnly) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDefaultValidityDays requires an ID field in the mutation") } - if m.require_privacy_set != nil { - fields = append(fields, group.FieldRequirePrivacySet) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDefaultValidityDays: %w", err) } - if m.default_mapped_model != nil { - fields = append(fields, group.FieldDefaultMappedModel) + return oldValue.DefaultValidityDays, nil +} + +// AddDefaultValidityDays adds i to the "default_validity_days" field. +func (m *GroupMutation) AddDefaultValidityDays(i int) { + if m.adddefault_validity_days != nil { + *m.adddefault_validity_days += i + } else { + m.adddefault_validity_days = &i } - if m.messages_dispatch_model_config != nil { - fields = append(fields, group.FieldMessagesDispatchModelConfig) +} + +// AddedDefaultValidityDays returns the value that was added to the "default_validity_days" field in this mutation. +func (m *GroupMutation) AddedDefaultValidityDays() (r int, exists bool) { + v := m.adddefault_validity_days + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *GroupMutation) Field(name string) (ent.Value, bool) { - switch name { - case group.FieldCreatedAt: - return m.CreatedAt() - case group.FieldUpdatedAt: - return m.UpdatedAt() - case group.FieldDeletedAt: - return m.DeletedAt() - case group.FieldName: - return m.Name() - case group.FieldDescription: - return m.Description() - case group.FieldRateMultiplier: - return m.RateMultiplier() - case group.FieldIsExclusive: - return m.IsExclusive() - case group.FieldStatus: - return m.Status() - case group.FieldPlatform: - return m.Platform() - case group.FieldSubscriptionType: - return m.SubscriptionType() - case group.FieldDailyLimitUsd: - return m.DailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.WeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.MonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.DefaultValidityDays() - case group.FieldImagePrice1k: - return m.ImagePrice1k() - case group.FieldImagePrice2k: - return m.ImagePrice2k() - case group.FieldImagePrice4k: - return m.ImagePrice4k() - case group.FieldClaudeCodeOnly: - return m.ClaudeCodeOnly() - case group.FieldFallbackGroupID: - return m.FallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.FallbackGroupIDOnInvalidRequest() - case group.FieldModelRouting: - return m.ModelRouting() - case group.FieldModelRoutingEnabled: - return m.ModelRoutingEnabled() - case group.FieldMcpXMLInject: - return m.McpXMLInject() - case group.FieldSupportedModelScopes: - return m.SupportedModelScopes() - case group.FieldSortOrder: - return m.SortOrder() - case group.FieldAllowMessagesDispatch: - return m.AllowMessagesDispatch() - case group.FieldRequireOauthOnly: - return m.RequireOauthOnly() - case group.FieldRequirePrivacySet: - return m.RequirePrivacySet() - case group.FieldDefaultMappedModel: - return m.DefaultMappedModel() - case group.FieldMessagesDispatchModelConfig: - return m.MessagesDispatchModelConfig() +// ResetDefaultValidityDays resets all changes to the "default_validity_days" field. +func (m *GroupMutation) ResetDefaultValidityDays() { + m.default_validity_days = nil + m.adddefault_validity_days = nil +} + +// SetImagePrice1k sets the "image_price_1k" field. +func (m *GroupMutation) SetImagePrice1k(f float64) { + m.image_price_1k = &f + m.addimage_price_1k = nil +} + +// ImagePrice1k returns the value of the "image_price_1k" field in the mutation. +func (m *GroupMutation) ImagePrice1k() (r float64, exists bool) { + v := m.image_price_1k + if v == nil { + return } - return nil, false + return *v, true } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case group.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case group.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) - case group.FieldDeletedAt: - return m.OldDeletedAt(ctx) - case group.FieldName: - return m.OldName(ctx) - case group.FieldDescription: - return m.OldDescription(ctx) - case group.FieldRateMultiplier: - return m.OldRateMultiplier(ctx) - case group.FieldIsExclusive: - return m.OldIsExclusive(ctx) - case group.FieldStatus: - return m.OldStatus(ctx) - case group.FieldPlatform: - return m.OldPlatform(ctx) - case group.FieldSubscriptionType: - return m.OldSubscriptionType(ctx) - case group.FieldDailyLimitUsd: - return m.OldDailyLimitUsd(ctx) - case group.FieldWeeklyLimitUsd: - return m.OldWeeklyLimitUsd(ctx) - case group.FieldMonthlyLimitUsd: - return m.OldMonthlyLimitUsd(ctx) - case group.FieldDefaultValidityDays: - return m.OldDefaultValidityDays(ctx) - case group.FieldImagePrice1k: - return m.OldImagePrice1k(ctx) - case group.FieldImagePrice2k: - return m.OldImagePrice2k(ctx) - case group.FieldImagePrice4k: - return m.OldImagePrice4k(ctx) - case group.FieldClaudeCodeOnly: - return m.OldClaudeCodeOnly(ctx) - case group.FieldFallbackGroupID: - return m.OldFallbackGroupID(ctx) - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.OldFallbackGroupIDOnInvalidRequest(ctx) - case group.FieldModelRouting: - return m.OldModelRouting(ctx) - case group.FieldModelRoutingEnabled: - return m.OldModelRoutingEnabled(ctx) - case group.FieldMcpXMLInject: - return m.OldMcpXMLInject(ctx) - case group.FieldSupportedModelScopes: - return m.OldSupportedModelScopes(ctx) - case group.FieldSortOrder: - return m.OldSortOrder(ctx) - case group.FieldAllowMessagesDispatch: - return m.OldAllowMessagesDispatch(ctx) - case group.FieldRequireOauthOnly: - return m.OldRequireOauthOnly(ctx) - case group.FieldRequirePrivacySet: - return m.OldRequirePrivacySet(ctx) - case group.FieldDefaultMappedModel: - return m.OldDefaultMappedModel(ctx) - case group.FieldMessagesDispatchModelConfig: - return m.OldMessagesDispatchModelConfig(ctx) +// OldImagePrice1k returns the old "image_price_1k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice1k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice1k is only allowed on UpdateOne operations") } - return nil, fmt.Errorf("unknown Group field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice1k requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice1k: %w", err) + } + return oldValue.ImagePrice1k, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GroupMutation) SetField(name string, value ent.Value) error { - switch name { - case group.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil - case group.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdatedAt(v) - return nil - case group.FieldDeletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDeletedAt(v) - return nil - case group.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case group.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) - return nil - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRateMultiplier(v) - return nil - case group.FieldIsExclusive: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetIsExclusive(v) - return nil - case group.FieldStatus: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStatus(v) - return nil - case group.FieldPlatform: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlatform(v) - return nil - case group.FieldSubscriptionType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionType(v) - return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDailyLimitUsd(v) - return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetWeeklyLimitUsd(v) - return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMonthlyLimitUsd(v) - return nil - case group.FieldDefaultValidityDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetImagePrice4k(v) - return nil - case group.FieldClaudeCodeOnly: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetClaudeCodeOnly(v) - return nil - case group.FieldFallbackGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFallbackGroupID(v) - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldModelRouting: - v, ok := value.(map[string][]int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRouting(v) - return nil - case group.FieldModelRoutingEnabled: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetModelRoutingEnabled(v) - return nil - case group.FieldMcpXMLInject: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMcpXMLInject(v) - return nil - case group.FieldSupportedModelScopes: - v, ok := value.([]string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSupportedModelScopes(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSortOrder(v) - return nil - case group.FieldAllowMessagesDispatch: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAllowMessagesDispatch(v) - return nil - case group.FieldRequireOauthOnly: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRequireOauthOnly(v) - return nil - case group.FieldRequirePrivacySet: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRequirePrivacySet(v) - return nil - case group.FieldDefaultMappedModel: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDefaultMappedModel(v) - return nil - case group.FieldMessagesDispatchModelConfig: - v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetMessagesDispatchModelConfig(v) - return nil +// AddImagePrice1k adds f to the "image_price_1k" field. +func (m *GroupMutation) AddImagePrice1k(f float64) { + if m.addimage_price_1k != nil { + *m.addimage_price_1k += f + } else { + m.addimage_price_1k = &f } - return fmt.Errorf("unknown Group field %s", name) } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *GroupMutation) AddedFields() []string { - var fields []string - if m.addrate_multiplier != nil { - fields = append(fields, group.FieldRateMultiplier) - } - if m.adddaily_limit_usd != nil { - fields = append(fields, group.FieldDailyLimitUsd) +// AddedImagePrice1k returns the value that was added to the "image_price_1k" field in this mutation. +func (m *GroupMutation) AddedImagePrice1k() (r float64, exists bool) { + v := m.addimage_price_1k + if v == nil { + return } - if m.addweekly_limit_usd != nil { - fields = append(fields, group.FieldWeeklyLimitUsd) + return *v, true +} + +// ClearImagePrice1k clears the value of the "image_price_1k" field. +func (m *GroupMutation) ClearImagePrice1k() { + m.image_price_1k = nil + m.addimage_price_1k = nil + m.clearedFields[group.FieldImagePrice1k] = struct{}{} +} + +// ImagePrice1kCleared returns if the "image_price_1k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice1kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice1k] + return ok +} + +// ResetImagePrice1k resets all changes to the "image_price_1k" field. +func (m *GroupMutation) ResetImagePrice1k() { + m.image_price_1k = nil + m.addimage_price_1k = nil + delete(m.clearedFields, group.FieldImagePrice1k) +} + +// SetImagePrice2k sets the "image_price_2k" field. +func (m *GroupMutation) SetImagePrice2k(f float64) { + m.image_price_2k = &f + m.addimage_price_2k = nil +} + +// ImagePrice2k returns the value of the "image_price_2k" field in the mutation. +func (m *GroupMutation) ImagePrice2k() (r float64, exists bool) { + v := m.image_price_2k + if v == nil { + return } - if m.addmonthly_limit_usd != nil { - fields = append(fields, group.FieldMonthlyLimitUsd) + return *v, true +} + +// OldImagePrice2k returns the old "image_price_2k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice2k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice2k is only allowed on UpdateOne operations") } - if m.adddefault_validity_days != nil { - fields = append(fields, group.FieldDefaultValidityDays) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice2k requires an ID field in the mutation") } - if m.addimage_price_1k != nil { - fields = append(fields, group.FieldImagePrice1k) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice2k: %w", err) } + return oldValue.ImagePrice2k, nil +} + +// AddImagePrice2k adds f to the "image_price_2k" field. +func (m *GroupMutation) AddImagePrice2k(f float64) { if m.addimage_price_2k != nil { - fields = append(fields, group.FieldImagePrice2k) + *m.addimage_price_2k += f + } else { + m.addimage_price_2k = &f } - if m.addimage_price_4k != nil { - fields = append(fields, group.FieldImagePrice4k) +} + +// AddedImagePrice2k returns the value that was added to the "image_price_2k" field in this mutation. +func (m *GroupMutation) AddedImagePrice2k() (r float64, exists bool) { + v := m.addimage_price_2k + if v == nil { + return } - if m.addfallback_group_id != nil { - fields = append(fields, group.FieldFallbackGroupID) + return *v, true +} + +// ClearImagePrice2k clears the value of the "image_price_2k" field. +func (m *GroupMutation) ClearImagePrice2k() { + m.image_price_2k = nil + m.addimage_price_2k = nil + m.clearedFields[group.FieldImagePrice2k] = struct{}{} +} + +// ImagePrice2kCleared returns if the "image_price_2k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice2kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice2k] + return ok +} + +// ResetImagePrice2k resets all changes to the "image_price_2k" field. +func (m *GroupMutation) ResetImagePrice2k() { + m.image_price_2k = nil + m.addimage_price_2k = nil + delete(m.clearedFields, group.FieldImagePrice2k) +} + +// SetImagePrice4k sets the "image_price_4k" field. +func (m *GroupMutation) SetImagePrice4k(f float64) { + m.image_price_4k = &f + m.addimage_price_4k = nil +} + +// ImagePrice4k returns the value of the "image_price_4k" field in the mutation. +func (m *GroupMutation) ImagePrice4k() (r float64, exists bool) { + v := m.image_price_4k + if v == nil { + return } - if m.addfallback_group_id_on_invalid_request != nil { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return *v, true +} + +// OldImagePrice4k returns the old "image_price_4k" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldImagePrice4k(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImagePrice4k is only allowed on UpdateOne operations") } - if m.addsort_order != nil { - fields = append(fields, group.FieldSortOrder) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImagePrice4k requires an ID field in the mutation") } - return fields + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImagePrice4k: %w", err) + } + return oldValue.ImagePrice4k, nil } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { - switch name { - case group.FieldRateMultiplier: - return m.AddedRateMultiplier() - case group.FieldDailyLimitUsd: - return m.AddedDailyLimitUsd() - case group.FieldWeeklyLimitUsd: - return m.AddedWeeklyLimitUsd() - case group.FieldMonthlyLimitUsd: - return m.AddedMonthlyLimitUsd() - case group.FieldDefaultValidityDays: - return m.AddedDefaultValidityDays() - case group.FieldImagePrice1k: - return m.AddedImagePrice1k() - case group.FieldImagePrice2k: - return m.AddedImagePrice2k() - case group.FieldImagePrice4k: - return m.AddedImagePrice4k() - case group.FieldFallbackGroupID: - return m.AddedFallbackGroupID() - case group.FieldFallbackGroupIDOnInvalidRequest: - return m.AddedFallbackGroupIDOnInvalidRequest() - case group.FieldSortOrder: - return m.AddedSortOrder() +// AddImagePrice4k adds f to the "image_price_4k" field. +func (m *GroupMutation) AddImagePrice4k(f float64) { + if m.addimage_price_4k != nil { + *m.addimage_price_4k += f + } else { + m.addimage_price_4k = &f } - return nil, false } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GroupMutation) AddField(name string, value ent.Value) error { - switch name { - case group.FieldRateMultiplier: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRateMultiplier(v) - return nil - case group.FieldDailyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddDailyLimitUsd(v) - return nil - case group.FieldWeeklyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddWeeklyLimitUsd(v) - return nil - case group.FieldMonthlyLimitUsd: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddMonthlyLimitUsd(v) - return nil - case group.FieldDefaultValidityDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddDefaultValidityDays(v) - return nil - case group.FieldImagePrice1k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice1k(v) - return nil - case group.FieldImagePrice2k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice2k(v) - return nil - case group.FieldImagePrice4k: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddImagePrice4k(v) - return nil - case group.FieldFallbackGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFallbackGroupID(v) - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFallbackGroupIDOnInvalidRequest(v) - return nil - case group.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) - return nil +// AddedImagePrice4k returns the value that was added to the "image_price_4k" field in this mutation. +func (m *GroupMutation) AddedImagePrice4k() (r float64, exists bool) { + v := m.addimage_price_4k + if v == nil { + return } - return fmt.Errorf("unknown Group numeric field %s", name) + return *v, true } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *GroupMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(group.FieldDeletedAt) { - fields = append(fields, group.FieldDeletedAt) - } - if m.FieldCleared(group.FieldDescription) { - fields = append(fields, group.FieldDescription) +// ClearImagePrice4k clears the value of the "image_price_4k" field. +func (m *GroupMutation) ClearImagePrice4k() { + m.image_price_4k = nil + m.addimage_price_4k = nil + m.clearedFields[group.FieldImagePrice4k] = struct{}{} +} + +// ImagePrice4kCleared returns if the "image_price_4k" field was cleared in this mutation. +func (m *GroupMutation) ImagePrice4kCleared() bool { + _, ok := m.clearedFields[group.FieldImagePrice4k] + return ok +} + +// ResetImagePrice4k resets all changes to the "image_price_4k" field. +func (m *GroupMutation) ResetImagePrice4k() { + m.image_price_4k = nil + m.addimage_price_4k = nil + delete(m.clearedFields, group.FieldImagePrice4k) +} + +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (m *GroupMutation) SetClaudeCodeOnly(b bool) { + m.claude_code_only = &b +} + +// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. +func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { + v := m.claude_code_only + if v == nil { + return } - if m.FieldCleared(group.FieldDailyLimitUsd) { - fields = append(fields, group.FieldDailyLimitUsd) + return *v, true +} + +// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") } - if m.FieldCleared(group.FieldWeeklyLimitUsd) { - fields = append(fields, group.FieldWeeklyLimitUsd) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") } - if m.FieldCleared(group.FieldMonthlyLimitUsd) { - fields = append(fields, group.FieldMonthlyLimitUsd) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) } - if m.FieldCleared(group.FieldImagePrice1k) { - fields = append(fields, group.FieldImagePrice1k) + return oldValue.ClaudeCodeOnly, nil +} + +// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. +func (m *GroupMutation) ResetClaudeCodeOnly() { + m.claude_code_only = nil +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (m *GroupMutation) SetFallbackGroupID(i int64) { + m.fallback_group_id = &i + m.addfallback_group_id = nil +} + +// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. +func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { + v := m.fallback_group_id + if v == nil { + return } - if m.FieldCleared(group.FieldImagePrice2k) { - fields = append(fields, group.FieldImagePrice2k) + return *v, true +} + +// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") } - if m.FieldCleared(group.FieldImagePrice4k) { - fields = append(fields, group.FieldImagePrice4k) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") } - if m.FieldCleared(group.FieldFallbackGroupID) { - fields = append(fields, group.FieldFallbackGroupID) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) } - if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { - fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + return oldValue.FallbackGroupID, nil +} + +// AddFallbackGroupID adds i to the "fallback_group_id" field. +func (m *GroupMutation) AddFallbackGroupID(i int64) { + if m.addfallback_group_id != nil { + *m.addfallback_group_id += i + } else { + m.addfallback_group_id = &i } - if m.FieldCleared(group.FieldModelRouting) { - fields = append(fields, group.FieldModelRouting) +} + +// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { + v := m.addfallback_group_id + if v == nil { + return } - return fields + return *v, true } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *GroupMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (m *GroupMutation) ClearFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + m.clearedFields[group.FieldFallbackGroupID] = struct{}{} +} + +// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupID] return ok } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *GroupMutation) ClearField(name string) error { - switch name { - case group.FieldDeletedAt: - m.ClearDeletedAt() - return nil - case group.FieldDescription: - m.ClearDescription() - return nil - case group.FieldDailyLimitUsd: - m.ClearDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ClearWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ClearMonthlyLimitUsd() - return nil - case group.FieldImagePrice1k: - m.ClearImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ClearImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ClearImagePrice4k() - return nil - case group.FieldFallbackGroupID: - m.ClearFallbackGroupID() - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ClearFallbackGroupIDOnInvalidRequest() - return nil - case group.FieldModelRouting: - m.ClearModelRouting() - return nil +// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. +func (m *GroupMutation) ResetFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + delete(m.clearedFields, group.FieldFallbackGroupID) +} + +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return } - return fmt.Errorf("unknown Group nullable field %s", name) + return *v, true } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *GroupMutation) ResetField(name string) error { - switch name { - case group.FieldCreatedAt: - m.ResetCreatedAt() - return nil - case group.FieldUpdatedAt: - m.ResetUpdatedAt() - return nil - case group.FieldDeletedAt: - m.ResetDeletedAt() - return nil - case group.FieldName: - m.ResetName() - return nil - case group.FieldDescription: - m.ResetDescription() - return nil - case group.FieldRateMultiplier: - m.ResetRateMultiplier() - return nil - case group.FieldIsExclusive: - m.ResetIsExclusive() - return nil - case group.FieldStatus: - m.ResetStatus() - return nil - case group.FieldPlatform: - m.ResetPlatform() - return nil - case group.FieldSubscriptionType: - m.ResetSubscriptionType() - return nil - case group.FieldDailyLimitUsd: - m.ResetDailyLimitUsd() - return nil - case group.FieldWeeklyLimitUsd: - m.ResetWeeklyLimitUsd() - return nil - case group.FieldMonthlyLimitUsd: - m.ResetMonthlyLimitUsd() - return nil - case group.FieldDefaultValidityDays: - m.ResetDefaultValidityDays() - return nil - case group.FieldImagePrice1k: - m.ResetImagePrice1k() - return nil - case group.FieldImagePrice2k: - m.ResetImagePrice2k() - return nil - case group.FieldImagePrice4k: - m.ResetImagePrice4k() - return nil - case group.FieldClaudeCodeOnly: - m.ResetClaudeCodeOnly() - return nil - case group.FieldFallbackGroupID: - m.ResetFallbackGroupID() - return nil - case group.FieldFallbackGroupIDOnInvalidRequest: - m.ResetFallbackGroupIDOnInvalidRequest() - return nil - case group.FieldModelRouting: - m.ResetModelRouting() - return nil - case group.FieldModelRoutingEnabled: - m.ResetModelRoutingEnabled() - return nil - case group.FieldMcpXMLInject: - m.ResetMcpXMLInject() - return nil - case group.FieldSupportedModelScopes: - m.ResetSupportedModelScopes() - return nil - case group.FieldSortOrder: - m.ResetSortOrder() - return nil - case group.FieldAllowMessagesDispatch: - m.ResetAllowMessagesDispatch() - return nil - case group.FieldRequireOauthOnly: - m.ResetRequireOauthOnly() - return nil - case group.FieldRequirePrivacySet: - m.ResetRequirePrivacySet() - return nil - case group.FieldDefaultMappedModel: - m.ResetDefaultMappedModel() - return nil - case group.FieldMessagesDispatchModelConfig: - m.ResetMessagesDispatchModelConfig() - return nil +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown Group field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *GroupMutation) AddedEdges() []string { - edges := make([]string, 0, 6) - if m.api_keys != nil { - edges = append(edges, group.EdgeAPIKeys) +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i } - if m.redeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return } - if m.subscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + +// SetModelRouting sets the "model_routing" field. +func (m *GroupMutation) SetModelRouting(value map[string][]int64) { + m.model_routing = &value +} + +// ModelRouting returns the value of the "model_routing" field in the mutation. +func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) { + v := m.model_routing + if v == nil { + return } - if m.usage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) + return *v, true +} + +// OldModelRouting returns the old "model_routing" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRouting is only allowed on UpdateOne operations") } - if m.accounts != nil { - edges = append(edges, group.EdgeAccounts) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRouting requires an ID field in the mutation") } - if m.allowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRouting: %w", err) } - return edges + return oldValue.ModelRouting, nil } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *GroupMutation) AddedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.api_keys)) - for id := range m.api_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.redeem_codes)) - for id := range m.redeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.subscriptions)) - for id := range m.subscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.usage_logs)) - for id := range m.usage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.accounts)) - for id := range m.accounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.allowed_users)) - for id := range m.allowed_users { - ids = append(ids, id) - } - return ids - } - return nil +// ClearModelRouting clears the value of the "model_routing" field. +func (m *GroupMutation) ClearModelRouting() { + m.model_routing = nil + m.clearedFields[group.FieldModelRouting] = struct{}{} } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *GroupMutation) RemovedEdges() []string { - edges := make([]string, 0, 6) - if m.removedapi_keys != nil { - edges = append(edges, group.EdgeAPIKeys) +// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation. +func (m *GroupMutation) ModelRoutingCleared() bool { + _, ok := m.clearedFields[group.FieldModelRouting] + return ok +} + +// ResetModelRouting resets all changes to the "model_routing" field. +func (m *GroupMutation) ResetModelRouting() { + m.model_routing = nil + delete(m.clearedFields, group.FieldModelRouting) +} + +// SetModelRoutingEnabled sets the "model_routing_enabled" field. +func (m *GroupMutation) SetModelRoutingEnabled(b bool) { + m.model_routing_enabled = &b +} + +// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation. +func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) { + v := m.model_routing_enabled + if v == nil { + return } - if m.removedredeem_codes != nil { - edges = append(edges, group.EdgeRedeemCodes) + return *v, true +} + +// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations") } - if m.removedsubscriptions != nil { - edges = append(edges, group.EdgeSubscriptions) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation") } - if m.removedusage_logs != nil { - edges = append(edges, group.EdgeUsageLogs) - } - if m.removedaccounts != nil { - edges = append(edges, group.EdgeAccounts) - } - if m.removedallowed_users != nil { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err) } - return edges + return oldValue.ModelRoutingEnabled, nil } -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *GroupMutation) RemovedIDs(name string) []ent.Value { - switch name { - case group.EdgeAPIKeys: - ids := make([]ent.Value, 0, len(m.removedapi_keys)) - for id := range m.removedapi_keys { - ids = append(ids, id) - } - return ids - case group.EdgeRedeemCodes: - ids := make([]ent.Value, 0, len(m.removedredeem_codes)) - for id := range m.removedredeem_codes { - ids = append(ids, id) - } - return ids - case group.EdgeSubscriptions: - ids := make([]ent.Value, 0, len(m.removedsubscriptions)) - for id := range m.removedsubscriptions { - ids = append(ids, id) - } - return ids - case group.EdgeUsageLogs: - ids := make([]ent.Value, 0, len(m.removedusage_logs)) - for id := range m.removedusage_logs { - ids = append(ids, id) - } - return ids - case group.EdgeAccounts: - ids := make([]ent.Value, 0, len(m.removedaccounts)) - for id := range m.removedaccounts { - ids = append(ids, id) - } - return ids - case group.EdgeAllowedUsers: - ids := make([]ent.Value, 0, len(m.removedallowed_users)) - for id := range m.removedallowed_users { - ids = append(ids, id) - } - return ids - } - return nil +// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field. +func (m *GroupMutation) ResetModelRoutingEnabled() { + m.model_routing_enabled = nil } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *GroupMutation) ClearedEdges() []string { - edges := make([]string, 0, 6) - if m.clearedapi_keys { - edges = append(edges, group.EdgeAPIKeys) - } - if m.clearedredeem_codes { - edges = append(edges, group.EdgeRedeemCodes) - } - if m.clearedsubscriptions { - edges = append(edges, group.EdgeSubscriptions) +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (m *GroupMutation) SetMcpXMLInject(b bool) { + m.mcp_xml_inject = &b +} + +// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. +func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { + v := m.mcp_xml_inject + if v == nil { + return } - if m.clearedusage_logs { - edges = append(edges, group.EdgeUsageLogs) + return *v, true +} + +// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") } - if m.clearedaccounts { - edges = append(edges, group.EdgeAccounts) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") } - if m.clearedallowed_users { - edges = append(edges, group.EdgeAllowedUsers) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) } - return edges + return oldValue.McpXMLInject, nil } -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *GroupMutation) EdgeCleared(name string) bool { - switch name { - case group.EdgeAPIKeys: - return m.clearedapi_keys - case group.EdgeRedeemCodes: - return m.clearedredeem_codes - case group.EdgeSubscriptions: - return m.clearedsubscriptions - case group.EdgeUsageLogs: - return m.clearedusage_logs - case group.EdgeAccounts: - return m.clearedaccounts - case group.EdgeAllowedUsers: - return m.clearedallowed_users - } - return false +// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. +func (m *GroupMutation) ResetMcpXMLInject() { + m.mcp_xml_inject = nil } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *GroupMutation) ClearEdge(name string) error { - switch name { +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (m *GroupMutation) SetSupportedModelScopes(s []string) { + m.supported_model_scopes = &s + m.appendsupported_model_scopes = nil +} + +// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. +func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { + v := m.supported_model_scopes + if v == nil { + return } - return fmt.Errorf("unknown Group unique edge %s", name) + return *v, true } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *GroupMutation) ResetEdge(name string) error { - switch name { - case group.EdgeAPIKeys: - m.ResetAPIKeys() - return nil - case group.EdgeRedeemCodes: - m.ResetRedeemCodes() - return nil - case group.EdgeSubscriptions: - m.ResetSubscriptions() - return nil - case group.EdgeUsageLogs: - m.ResetUsageLogs() - return nil - case group.EdgeAccounts: - m.ResetAccounts() - return nil - case group.EdgeAllowedUsers: - m.ResetAllowedUsers() - return nil +// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown Group edge %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + } + return oldValue.SupportedModelScopes, nil } -// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. -type IdempotencyRecordMutation struct { - config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - scope *string - idempotency_key_hash *string - request_fingerprint *string - status *string - response_status *int - addresponse_status *int - response_body *string - error_reason *string - locked_until *time.Time - expires_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*IdempotencyRecord, error) - predicates []predicate.IdempotencyRecord +// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. +func (m *GroupMutation) AppendSupportedModelScopes(s []string) { + m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) } -var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) +// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. +func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { + if len(m.appendsupported_model_scopes) == 0 { + return nil, false + } + return m.appendsupported_model_scopes, true +} -// idempotencyrecordOption allows management of the mutation configuration using functional options. -type idempotencyrecordOption func(*IdempotencyRecordMutation) +// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. +func (m *GroupMutation) ResetSupportedModelScopes() { + m.supported_model_scopes = nil + m.appendsupported_model_scopes = nil +} -// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. -func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { - m := &IdempotencyRecordMutation{ - config: c, - op: op, - typ: TypeIdempotencyRecord, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) - } - return m +// SetSortOrder sets the "sort_order" field. +func (m *GroupMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil } -// withIdempotencyRecordID sets the ID field of the mutation. -func withIdempotencyRecordID(id int64) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { - var ( - err error - once sync.Once - value *IdempotencyRecord - ) - m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().IdempotencyRecord.Get(ctx, id) - } - }) - return value, err - } - m.id = &id +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *GroupMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return } + return *v, true } -// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. -func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { - return func(m *IdempotencyRecordMutation) { - m.oldValue = func(context.Context) (*IdempotencyRecord, error) { - return node, nil - } - m.id = &node.ID +// OldSortOrder returns the old "sort_order" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil } -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m IdempotencyRecordMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client -} - -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m IdempotencyRecordMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") +// AddSortOrder adds i to the "sort_order" field. +func (m *GroupMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i } - tx := &Tx{config: m.config} - tx.init() - return tx, nil } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { - if m.id == nil { +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { return } - return *m.id, true + return *v, true } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *GroupMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil } -// SetCreatedAt sets the "created_at" field. -func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetAllowMessagesDispatch sets the "allow_messages_dispatch" field. +func (m *GroupMutation) SetAllowMessagesDispatch(b bool) { + m.allow_messages_dispatch = &b } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// AllowMessagesDispatch returns the value of the "allow_messages_dispatch" field in the mutation. +func (m *GroupMutation) AllowMessagesDispatch() (r bool, exists bool) { + v := m.allow_messages_dispatch if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldAllowMessagesDispatch returns the old "allow_messages_dispatch" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *GroupMutation) OldAllowMessagesDispatch(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldAllowMessagesDispatch is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldAllowMessagesDispatch requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldAllowMessagesDispatch: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.AllowMessagesDispatch, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *IdempotencyRecordMutation) ResetCreatedAt() { - m.created_at = nil +// ResetAllowMessagesDispatch resets all changes to the "allow_messages_dispatch" field. +func (m *GroupMutation) ResetAllowMessagesDispatch() { + m.allow_messages_dispatch = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetRequireOauthOnly sets the "require_oauth_only" field. +func (m *GroupMutation) SetRequireOauthOnly(b bool) { + m.require_oauth_only = &b } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// RequireOauthOnly returns the value of the "require_oauth_only" field in the mutation. +func (m *GroupMutation) RequireOauthOnly() (r bool, exists bool) { + v := m.require_oauth_only if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldRequireOauthOnly returns the old "require_oauth_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *GroupMutation) OldRequireOauthOnly(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldRequireOauthOnly is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldRequireOauthOnly requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldRequireOauthOnly: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.RequireOauthOnly, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *IdempotencyRecordMutation) ResetUpdatedAt() { - m.updated_at = nil +// ResetRequireOauthOnly resets all changes to the "require_oauth_only" field. +func (m *GroupMutation) ResetRequireOauthOnly() { + m.require_oauth_only = nil } -// SetScope sets the "scope" field. -func (m *IdempotencyRecordMutation) SetScope(s string) { - m.scope = &s +// SetRequirePrivacySet sets the "require_privacy_set" field. +func (m *GroupMutation) SetRequirePrivacySet(b bool) { + m.require_privacy_set = &b } -// Scope returns the value of the "scope" field in the mutation. -func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { - v := m.scope +// RequirePrivacySet returns the value of the "require_privacy_set" field in the mutation. +func (m *GroupMutation) RequirePrivacySet() (r bool, exists bool) { + v := m.require_privacy_set if v == nil { return } return *v, true } -// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldRequirePrivacySet returns the old "require_privacy_set" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldRequirePrivacySet(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldScope is only allowed on UpdateOne operations") + return v, errors.New("OldRequirePrivacySet is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldScope requires an ID field in the mutation") + return v, errors.New("OldRequirePrivacySet requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldScope: %w", err) + return v, fmt.Errorf("querying old value for OldRequirePrivacySet: %w", err) } - return oldValue.Scope, nil + return oldValue.RequirePrivacySet, nil } -// ResetScope resets all changes to the "scope" field. -func (m *IdempotencyRecordMutation) ResetScope() { - m.scope = nil +// ResetRequirePrivacySet resets all changes to the "require_privacy_set" field. +func (m *GroupMutation) ResetRequirePrivacySet() { + m.require_privacy_set = nil } -// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. -func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { - m.idempotency_key_hash = &s +// SetDefaultMappedModel sets the "default_mapped_model" field. +func (m *GroupMutation) SetDefaultMappedModel(s string) { + m.default_mapped_model = &s } -// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. -func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { - v := m.idempotency_key_hash +// DefaultMappedModel returns the value of the "default_mapped_model" field in the mutation. +func (m *GroupMutation) DefaultMappedModel() (r string, exists bool) { + v := m.default_mapped_model if v == nil { return } return *v, true } -// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldDefaultMappedModel returns the old "default_mapped_model" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldDefaultMappedModel(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + return v, errors.New("OldDefaultMappedModel is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + return v, errors.New("OldDefaultMappedModel requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + return v, fmt.Errorf("querying old value for OldDefaultMappedModel: %w", err) } - return oldValue.IdempotencyKeyHash, nil + return oldValue.DefaultMappedModel, nil } -// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. -func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { - m.idempotency_key_hash = nil +// ResetDefaultMappedModel resets all changes to the "default_mapped_model" field. +func (m *GroupMutation) ResetDefaultMappedModel() { + m.default_mapped_model = nil } -// SetRequestFingerprint sets the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { - m.request_fingerprint = &s +// SetMessagesDispatchModelConfig sets the "messages_dispatch_model_config" field. +func (m *GroupMutation) SetMessagesDispatchModelConfig(damdmc domain.OpenAIMessagesDispatchModelConfig) { + m.messages_dispatch_model_config = &damdmc } -// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. -func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { - v := m.request_fingerprint +// MessagesDispatchModelConfig returns the value of the "messages_dispatch_model_config" field in the mutation. +func (m *GroupMutation) MessagesDispatchModelConfig() (r domain.OpenAIMessagesDispatchModelConfig, exists bool) { + v := m.messages_dispatch_model_config if v == nil { return } return *v, true } -// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// OldMessagesDispatchModelConfig returns the old "messages_dispatch_model_config" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { +func (m *GroupMutation) OldMessagesDispatchModelConfig(ctx context.Context) (v domain.OpenAIMessagesDispatchModelConfig, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + return v, errors.New("OldMessagesDispatchModelConfig is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + return v, errors.New("OldMessagesDispatchModelConfig requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + return v, fmt.Errorf("querying old value for OldMessagesDispatchModelConfig: %w", err) } - return oldValue.RequestFingerprint, nil + return oldValue.MessagesDispatchModelConfig, nil } -// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. -func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { - m.request_fingerprint = nil +// ResetMessagesDispatchModelConfig resets all changes to the "messages_dispatch_model_config" field. +func (m *GroupMutation) ResetMessagesDispatchModelConfig() { + m.messages_dispatch_model_config = nil } -// SetStatus sets the "status" field. -func (m *IdempotencyRecordMutation) SetStatus(s string) { - m.status = &s +// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. +func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { + if m.api_keys == nil { + m.api_keys = make(map[int64]struct{}) + } + for i := range ids { + m.api_keys[ids[i]] = struct{}{} + } } -// Status returns the value of the "status" field in the mutation. -func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return - } - return *v, true +// ClearAPIKeys clears the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) ClearAPIKeys() { + m.clearedapi_keys = true } -// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") +// APIKeysCleared reports if the "api_keys" edge to the APIKey entity was cleared. +func (m *GroupMutation) APIKeysCleared() bool { + return m.clearedapi_keys +} + +// RemoveAPIKeyIDs removes the "api_keys" edge to the APIKey entity by IDs. +func (m *GroupMutation) RemoveAPIKeyIDs(ids ...int64) { + if m.removedapi_keys == nil { + m.removedapi_keys = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") + for i := range ids { + delete(m.api_keys, ids[i]) + m.removedapi_keys[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) +} + +// RemovedAPIKeys returns the removed IDs of the "api_keys" edge to the APIKey entity. +func (m *GroupMutation) RemovedAPIKeysIDs() (ids []int64) { + for id := range m.removedapi_keys { + ids = append(ids, id) } - return oldValue.Status, nil + return } -// ResetStatus resets all changes to the "status" field. -func (m *IdempotencyRecordMutation) ResetStatus() { - m.status = nil +// APIKeysIDs returns the "api_keys" edge IDs in the mutation. +func (m *GroupMutation) APIKeysIDs() (ids []int64) { + for id := range m.api_keys { + ids = append(ids, id) + } + return } -// SetResponseStatus sets the "response_status" field. -func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { - m.response_status = &i - m.addresponse_status = nil +// ResetAPIKeys resets all changes to the "api_keys" edge. +func (m *GroupMutation) ResetAPIKeys() { + m.api_keys = nil + m.clearedapi_keys = false + m.removedapi_keys = nil } -// ResponseStatus returns the value of the "response_status" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { - v := m.response_status - if v == nil { - return +// AddRedeemCodeIDs adds the "redeem_codes" edge to the RedeemCode entity by ids. +func (m *GroupMutation) AddRedeemCodeIDs(ids ...int64) { + if m.redeem_codes == nil { + m.redeem_codes = make(map[int64]struct{}) + } + for i := range ids { + m.redeem_codes[ids[i]] = struct{}{} } - return *v, true } -// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseStatus requires an ID field in the mutation") +// ClearRedeemCodes clears the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) ClearRedeemCodes() { + m.clearedredeem_codes = true +} + +// RedeemCodesCleared reports if the "redeem_codes" edge to the RedeemCode entity was cleared. +func (m *GroupMutation) RedeemCodesCleared() bool { + return m.clearedredeem_codes +} + +// RemoveRedeemCodeIDs removes the "redeem_codes" edge to the RedeemCode entity by IDs. +func (m *GroupMutation) RemoveRedeemCodeIDs(ids ...int64) { + if m.removedredeem_codes == nil { + m.removedredeem_codes = make(map[int64]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + for i := range ids { + delete(m.redeem_codes, ids[i]) + m.removedredeem_codes[ids[i]] = struct{}{} } - return oldValue.ResponseStatus, nil } -// AddResponseStatus adds i to the "response_status" field. -func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { - if m.addresponse_status != nil { - *m.addresponse_status += i - } else { - m.addresponse_status = &i +// RemovedRedeemCodes returns the removed IDs of the "redeem_codes" edge to the RedeemCode entity. +func (m *GroupMutation) RemovedRedeemCodesIDs() (ids []int64) { + for id := range m.removedredeem_codes { + ids = append(ids, id) } + return } -// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. -func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { - v := m.addresponse_status - if v == nil { - return +// RedeemCodesIDs returns the "redeem_codes" edge IDs in the mutation. +func (m *GroupMutation) RedeemCodesIDs() (ids []int64) { + for id := range m.redeem_codes { + ids = append(ids, id) } - return *v, true + return } -// ClearResponseStatus clears the value of the "response_status" field. -func (m *IdempotencyRecordMutation) ClearResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +// ResetRedeemCodes resets all changes to the "redeem_codes" edge. +func (m *GroupMutation) ResetRedeemCodes() { + m.redeem_codes = nil + m.clearedredeem_codes = false + m.removedredeem_codes = nil } -// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] - return ok +// AddSubscriptionIDs adds the "subscriptions" edge to the UserSubscription entity by ids. +func (m *GroupMutation) AddSubscriptionIDs(ids ...int64) { + if m.subscriptions == nil { + m.subscriptions = make(map[int64]struct{}) + } + for i := range ids { + m.subscriptions[ids[i]] = struct{}{} + } } -// ResetResponseStatus resets all changes to the "response_status" field. -func (m *IdempotencyRecordMutation) ResetResponseStatus() { - m.response_status = nil - m.addresponse_status = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +// ClearSubscriptions clears the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) ClearSubscriptions() { + m.clearedsubscriptions = true } -// SetResponseBody sets the "response_body" field. -func (m *IdempotencyRecordMutation) SetResponseBody(s string) { - m.response_body = &s +// SubscriptionsCleared reports if the "subscriptions" edge to the UserSubscription entity was cleared. +func (m *GroupMutation) SubscriptionsCleared() bool { + return m.clearedsubscriptions } -// ResponseBody returns the value of the "response_body" field in the mutation. -func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { - v := m.response_body - if v == nil { - return +// RemoveSubscriptionIDs removes the "subscriptions" edge to the UserSubscription entity by IDs. +func (m *GroupMutation) RemoveSubscriptionIDs(ids ...int64) { + if m.removedsubscriptions == nil { + m.removedsubscriptions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.subscriptions, ids[i]) + m.removedsubscriptions[ids[i]] = struct{}{} } - return *v, true } -// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldResponseBody requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) +// RemovedSubscriptions returns the removed IDs of the "subscriptions" edge to the UserSubscription entity. +func (m *GroupMutation) RemovedSubscriptionsIDs() (ids []int64) { + for id := range m.removedsubscriptions { + ids = append(ids, id) } - return oldValue.ResponseBody, nil + return } -// ClearResponseBody clears the value of the "response_body" field. -func (m *IdempotencyRecordMutation) ClearResponseBody() { - m.response_body = nil - m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +// SubscriptionsIDs returns the "subscriptions" edge IDs in the mutation. +func (m *GroupMutation) SubscriptionsIDs() (ids []int64) { + for id := range m.subscriptions { + ids = append(ids, id) + } + return } -// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] - return ok +// ResetSubscriptions resets all changes to the "subscriptions" edge. +func (m *GroupMutation) ResetSubscriptions() { + m.subscriptions = nil + m.clearedsubscriptions = false + m.removedsubscriptions = nil } -// ResetResponseBody resets all changes to the "response_body" field. -func (m *IdempotencyRecordMutation) ResetResponseBody() { - m.response_body = nil - delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +// AddUsageLogIDs adds the "usage_logs" edge to the UsageLog entity by ids. +func (m *GroupMutation) AddUsageLogIDs(ids ...int64) { + if m.usage_logs == nil { + m.usage_logs = make(map[int64]struct{}) + } + for i := range ids { + m.usage_logs[ids[i]] = struct{}{} + } } -// SetErrorReason sets the "error_reason" field. -func (m *IdempotencyRecordMutation) SetErrorReason(s string) { - m.error_reason = &s +// ClearUsageLogs clears the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) ClearUsageLogs() { + m.clearedusage_logs = true } -// ErrorReason returns the value of the "error_reason" field in the mutation. -func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { - v := m.error_reason - if v == nil { - return - } - return *v, true +// UsageLogsCleared reports if the "usage_logs" edge to the UsageLog entity was cleared. +func (m *GroupMutation) UsageLogsCleared() bool { + return m.clearedusage_logs } -// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") +// RemoveUsageLogIDs removes the "usage_logs" edge to the UsageLog entity by IDs. +func (m *GroupMutation) RemoveUsageLogIDs(ids ...int64) { + if m.removedusage_logs == nil { + m.removedusage_logs = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldErrorReason requires an ID field in the mutation") + for i := range ids { + delete(m.usage_logs, ids[i]) + m.removedusage_logs[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) +} + +// RemovedUsageLogs returns the removed IDs of the "usage_logs" edge to the UsageLog entity. +func (m *GroupMutation) RemovedUsageLogsIDs() (ids []int64) { + for id := range m.removedusage_logs { + ids = append(ids, id) } - return oldValue.ErrorReason, nil + return } -// ClearErrorReason clears the value of the "error_reason" field. -func (m *IdempotencyRecordMutation) ClearErrorReason() { - m.error_reason = nil - m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +// UsageLogsIDs returns the "usage_logs" edge IDs in the mutation. +func (m *GroupMutation) UsageLogsIDs() (ids []int64) { + for id := range m.usage_logs { + ids = append(ids, id) + } + return } -// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] - return ok +// ResetUsageLogs resets all changes to the "usage_logs" edge. +func (m *GroupMutation) ResetUsageLogs() { + m.usage_logs = nil + m.clearedusage_logs = false + m.removedusage_logs = nil } -// ResetErrorReason resets all changes to the "error_reason" field. -func (m *IdempotencyRecordMutation) ResetErrorReason() { - m.error_reason = nil - delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +// AddAccountIDs adds the "accounts" edge to the Account entity by ids. +func (m *GroupMutation) AddAccountIDs(ids ...int64) { + if m.accounts == nil { + m.accounts = make(map[int64]struct{}) + } + for i := range ids { + m.accounts[ids[i]] = struct{}{} + } } -// SetLockedUntil sets the "locked_until" field. -func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { - m.locked_until = &t +// ClearAccounts clears the "accounts" edge to the Account entity. +func (m *GroupMutation) ClearAccounts() { + m.clearedaccounts = true } -// LockedUntil returns the value of the "locked_until" field in the mutation. -func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { - v := m.locked_until - if v == nil { - return - } - return *v, true +// AccountsCleared reports if the "accounts" edge to the Account entity was cleared. +func (m *GroupMutation) AccountsCleared() bool { + return m.clearedaccounts } -// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") +// RemoveAccountIDs removes the "accounts" edge to the Account entity by IDs. +func (m *GroupMutation) RemoveAccountIDs(ids ...int64) { + if m.removedaccounts == nil { + m.removedaccounts = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLockedUntil requires an ID field in the mutation") + for i := range ids { + delete(m.accounts, ids[i]) + m.removedaccounts[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) +} + +// RemovedAccounts returns the removed IDs of the "accounts" edge to the Account entity. +func (m *GroupMutation) RemovedAccountsIDs() (ids []int64) { + for id := range m.removedaccounts { + ids = append(ids, id) } - return oldValue.LockedUntil, nil + return } -// ClearLockedUntil clears the value of the "locked_until" field. -func (m *IdempotencyRecordMutation) ClearLockedUntil() { - m.locked_until = nil - m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +// AccountsIDs returns the "accounts" edge IDs in the mutation. +func (m *GroupMutation) AccountsIDs() (ids []int64) { + for id := range m.accounts { + ids = append(ids, id) + } + return } -// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. -func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { - _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] - return ok +// ResetAccounts resets all changes to the "accounts" edge. +func (m *GroupMutation) ResetAccounts() { + m.accounts = nil + m.clearedaccounts = false + m.removedaccounts = nil } -// ResetLockedUntil resets all changes to the "locked_until" field. -func (m *IdempotencyRecordMutation) ResetLockedUntil() { - m.locked_until = nil - delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +// AddAllowedUserIDs adds the "allowed_users" edge to the User entity by ids. +func (m *GroupMutation) AddAllowedUserIDs(ids ...int64) { + if m.allowed_users == nil { + m.allowed_users = make(map[int64]struct{}) + } + for i := range ids { + m.allowed_users[ids[i]] = struct{}{} + } } -// SetExpiresAt sets the "expires_at" field. -func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// ClearAllowedUsers clears the "allowed_users" edge to the User entity. +func (m *GroupMutation) ClearAllowedUsers() { + m.clearedallowed_users = true } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at - if v == nil { - return - } - return *v, true +// AllowedUsersCleared reports if the "allowed_users" edge to the User entity was cleared. +func (m *GroupMutation) AllowedUsersCleared() bool { + return m.clearedallowed_users } -// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. -// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") +// RemoveAllowedUserIDs removes the "allowed_users" edge to the User entity by IDs. +func (m *GroupMutation) RemoveAllowedUserIDs(ids ...int64) { + if m.removedallowed_users == nil { + m.removedallowed_users = make(map[int64]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + for i := range ids { + delete(m.allowed_users, ids[i]) + m.removedallowed_users[ids[i]] = struct{}{} } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) +} + +// RemovedAllowedUsers returns the removed IDs of the "allowed_users" edge to the User entity. +func (m *GroupMutation) RemovedAllowedUsersIDs() (ids []int64) { + for id := range m.removedallowed_users { + ids = append(ids, id) } - return oldValue.ExpiresAt, nil + return } -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *IdempotencyRecordMutation) ResetExpiresAt() { - m.expires_at = nil +// AllowedUsersIDs returns the "allowed_users" edge IDs in the mutation. +func (m *GroupMutation) AllowedUsersIDs() (ids []int64) { + for id := range m.allowed_users { + ids = append(ids, id) + } + return } -// Where appends a list predicates to the IdempotencyRecordMutation builder. -func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { +// ResetAllowedUsers resets all changes to the "allowed_users" edge. +func (m *GroupMutation) ResetAllowedUsers() { + m.allowed_users = nil + m.clearedallowed_users = false + m.removedallowed_users = nil +} + +// Where appends a list predicates to the GroupMutation builder. +func (m *GroupMutation) Where(ps ...predicate.Group) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// WhereP appends storage-level predicates to the GroupMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.IdempotencyRecord, len(ps)) +func (m *GroupMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Group, len(ps)) for i := range ps { p[i] = ps[i] } @@ -11816,215 +12030,511 @@ func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *IdempotencyRecordMutation) Op() Op { +func (m *GroupMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *IdempotencyRecordMutation) SetOp(op Op) { +func (m *GroupMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (IdempotencyRecord). -func (m *IdempotencyRecordMutation) Type() string { +// Type returns the node type of this mutation (Group). +func (m *GroupMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *IdempotencyRecordMutation) Fields() []string { - fields := make([]string, 0, 11) +func (m *GroupMutation) Fields() []string { + fields := make([]string, 0, 30) if m.created_at != nil { - fields = append(fields, idempotencyrecord.FieldCreatedAt) + fields = append(fields, group.FieldCreatedAt) } if m.updated_at != nil { - fields = append(fields, idempotencyrecord.FieldUpdatedAt) + fields = append(fields, group.FieldUpdatedAt) } - if m.scope != nil { - fields = append(fields, idempotencyrecord.FieldScope) + if m.deleted_at != nil { + fields = append(fields, group.FieldDeletedAt) } - if m.idempotency_key_hash != nil { - fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + if m.name != nil { + fields = append(fields, group.FieldName) } - if m.request_fingerprint != nil { - fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + if m.description != nil { + fields = append(fields, group.FieldDescription) } - if m.status != nil { - fields = append(fields, idempotencyrecord.FieldStatus) + if m.rate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) } - if m.response_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.is_exclusive != nil { + fields = append(fields, group.FieldIsExclusive) } - if m.response_body != nil { - fields = append(fields, idempotencyrecord.FieldResponseBody) + if m.status != nil { + fields = append(fields, group.FieldStatus) } - if m.error_reason != nil { - fields = append(fields, idempotencyrecord.FieldErrorReason) + if m.platform != nil { + fields = append(fields, group.FieldPlatform) } - if m.locked_until != nil { - fields = append(fields, idempotencyrecord.FieldLockedUntil) + if m.subscription_type != nil { + fields = append(fields, group.FieldSubscriptionType) } - if m.expires_at != nil { - fields = append(fields, idempotencyrecord.FieldExpiresAt) + if m.daily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) } - return fields -} - -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { - switch name { - case idempotencyrecord.FieldCreatedAt: - return m.CreatedAt() - case idempotencyrecord.FieldUpdatedAt: - return m.UpdatedAt() - case idempotencyrecord.FieldScope: - return m.Scope() - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.IdempotencyKeyHash() - case idempotencyrecord.FieldRequestFingerprint: - return m.RequestFingerprint() - case idempotencyrecord.FieldStatus: - return m.Status() - case idempotencyrecord.FieldResponseStatus: - return m.ResponseStatus() - case idempotencyrecord.FieldResponseBody: - return m.ResponseBody() - case idempotencyrecord.FieldErrorReason: - return m.ErrorReason() - case idempotencyrecord.FieldLockedUntil: - return m.LockedUntil() - case idempotencyrecord.FieldExpiresAt: - return m.ExpiresAt() + if m.weekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) } - return nil, false + if m.monthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.default_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.image_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.image_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.image_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.claude_code_only != nil { + fields = append(fields, group.FieldClaudeCodeOnly) + } + if m.fallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.model_routing != nil { + fields = append(fields, group.FieldModelRouting) + } + if m.model_routing_enabled != nil { + fields = append(fields, group.FieldModelRoutingEnabled) + } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } + if m.sort_order != nil { + fields = append(fields, group.FieldSortOrder) + } + if m.allow_messages_dispatch != nil { + fields = append(fields, group.FieldAllowMessagesDispatch) + } + if m.require_oauth_only != nil { + fields = append(fields, group.FieldRequireOauthOnly) + } + if m.require_privacy_set != nil { + fields = append(fields, group.FieldRequirePrivacySet) + } + if m.default_mapped_model != nil { + fields = append(fields, group.FieldDefaultMappedModel) + } + if m.messages_dispatch_model_config != nil { + fields = append(fields, group.FieldMessagesDispatchModelConfig) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *GroupMutation) Field(name string) (ent.Value, bool) { + switch name { + case group.FieldCreatedAt: + return m.CreatedAt() + case group.FieldUpdatedAt: + return m.UpdatedAt() + case group.FieldDeletedAt: + return m.DeletedAt() + case group.FieldName: + return m.Name() + case group.FieldDescription: + return m.Description() + case group.FieldRateMultiplier: + return m.RateMultiplier() + case group.FieldIsExclusive: + return m.IsExclusive() + case group.FieldStatus: + return m.Status() + case group.FieldPlatform: + return m.Platform() + case group.FieldSubscriptionType: + return m.SubscriptionType() + case group.FieldDailyLimitUsd: + return m.DailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.WeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.MonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.DefaultValidityDays() + case group.FieldImagePrice1k: + return m.ImagePrice1k() + case group.FieldImagePrice2k: + return m.ImagePrice2k() + case group.FieldImagePrice4k: + return m.ImagePrice4k() + case group.FieldClaudeCodeOnly: + return m.ClaudeCodeOnly() + case group.FieldFallbackGroupID: + return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() + case group.FieldModelRouting: + return m.ModelRouting() + case group.FieldModelRoutingEnabled: + return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() + case group.FieldSortOrder: + return m.SortOrder() + case group.FieldAllowMessagesDispatch: + return m.AllowMessagesDispatch() + case group.FieldRequireOauthOnly: + return m.RequireOauthOnly() + case group.FieldRequirePrivacySet: + return m.RequirePrivacySet() + case group.FieldDefaultMappedModel: + return m.DefaultMappedModel() + case group.FieldMessagesDispatchModelConfig: + return m.MessagesDispatchModelConfig() + } + return nil, false } // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: return m.OldCreatedAt(ctx) - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: return m.OldUpdatedAt(ctx) - case idempotencyrecord.FieldScope: - return m.OldScope(ctx) - case idempotencyrecord.FieldIdempotencyKeyHash: - return m.OldIdempotencyKeyHash(ctx) - case idempotencyrecord.FieldRequestFingerprint: - return m.OldRequestFingerprint(ctx) - case idempotencyrecord.FieldStatus: + case group.FieldDeletedAt: + return m.OldDeletedAt(ctx) + case group.FieldName: + return m.OldName(ctx) + case group.FieldDescription: + return m.OldDescription(ctx) + case group.FieldRateMultiplier: + return m.OldRateMultiplier(ctx) + case group.FieldIsExclusive: + return m.OldIsExclusive(ctx) + case group.FieldStatus: return m.OldStatus(ctx) - case idempotencyrecord.FieldResponseStatus: - return m.OldResponseStatus(ctx) - case idempotencyrecord.FieldResponseBody: - return m.OldResponseBody(ctx) - case idempotencyrecord.FieldErrorReason: - return m.OldErrorReason(ctx) - case idempotencyrecord.FieldLockedUntil: - return m.OldLockedUntil(ctx) - case idempotencyrecord.FieldExpiresAt: - return m.OldExpiresAt(ctx) + case group.FieldPlatform: + return m.OldPlatform(ctx) + case group.FieldSubscriptionType: + return m.OldSubscriptionType(ctx) + case group.FieldDailyLimitUsd: + return m.OldDailyLimitUsd(ctx) + case group.FieldWeeklyLimitUsd: + return m.OldWeeklyLimitUsd(ctx) + case group.FieldMonthlyLimitUsd: + return m.OldMonthlyLimitUsd(ctx) + case group.FieldDefaultValidityDays: + return m.OldDefaultValidityDays(ctx) + case group.FieldImagePrice1k: + return m.OldImagePrice1k(ctx) + case group.FieldImagePrice2k: + return m.OldImagePrice2k(ctx) + case group.FieldImagePrice4k: + return m.OldImagePrice4k(ctx) + case group.FieldClaudeCodeOnly: + return m.OldClaudeCodeOnly(ctx) + case group.FieldFallbackGroupID: + return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) + case group.FieldModelRouting: + return m.OldModelRouting(ctx) + case group.FieldModelRoutingEnabled: + return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) + case group.FieldSortOrder: + return m.OldSortOrder(ctx) + case group.FieldAllowMessagesDispatch: + return m.OldAllowMessagesDispatch(ctx) + case group.FieldRequireOauthOnly: + return m.OldRequireOauthOnly(ctx) + case group.FieldRequirePrivacySet: + return m.OldRequirePrivacySet(ctx) + case group.FieldDefaultMappedModel: + return m.OldDefaultMappedModel(ctx) + case group.FieldMessagesDispatchModelConfig: + return m.OldMessagesDispatchModelConfig(ctx) } - return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) + return nil, fmt.Errorf("unknown Group field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { +func (m *GroupMutation) SetField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetCreatedAt(v) return nil - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } m.SetUpdatedAt(v) return nil - case idempotencyrecord.FieldScope: - v, ok := value.(string) + case group.FieldDeletedAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetScope(v) + m.SetDeletedAt(v) return nil - case idempotencyrecord.FieldIdempotencyKeyHash: + case group.FieldName: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetIdempotencyKeyHash(v) + m.SetName(v) return nil - case idempotencyrecord.FieldRequestFingerprint: + case group.FieldDescription: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRequestFingerprint(v) + m.SetDescription(v) return nil - case idempotencyrecord.FieldStatus: - v, ok := value.(string) + case group.FieldRateMultiplier: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetStatus(v) + m.SetRateMultiplier(v) return nil - case idempotencyrecord.FieldResponseStatus: - v, ok := value.(int) + case group.FieldIsExclusive: + v, ok := value.(bool) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseStatus(v) + m.SetIsExclusive(v) return nil - case idempotencyrecord.FieldResponseBody: + case group.FieldStatus: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetResponseBody(v) + m.SetStatus(v) return nil - case idempotencyrecord.FieldErrorReason: + case group.FieldPlatform: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetErrorReason(v) + m.SetPlatform(v) return nil - case idempotencyrecord.FieldLockedUntil: - v, ok := value.(time.Time) + case group.FieldSubscriptionType: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLockedUntil(v) + m.SetSubscriptionType(v) return nil - case idempotencyrecord.FieldExpiresAt: - v, ok := value.(time.Time) + case group.FieldDailyLimitUsd: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetExpiresAt(v) + m.SetDailyLimitUsd(v) + return nil + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWeeklyLimitUsd(v) + return nil + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMonthlyLimitUsd(v) + return nil + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImagePrice4k(v) + return nil + case group.FieldClaudeCodeOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaudeCodeOnly(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldModelRouting: + v, ok := value.(map[string][]int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRouting(v) + return nil + case group.FieldModelRoutingEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetModelRoutingEnabled(v) + return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) + return nil + case group.FieldAllowMessagesDispatch: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowMessagesDispatch(v) + return nil + case group.FieldRequireOauthOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequireOauthOnly(v) + return nil + case group.FieldRequirePrivacySet: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequirePrivacySet(v) + return nil + case group.FieldDefaultMappedModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDefaultMappedModel(v) + return nil + case group.FieldMessagesDispatchModelConfig: + v, ok := value.(domain.OpenAIMessagesDispatchModelConfig) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMessagesDispatchModelConfig(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown Group field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *IdempotencyRecordMutation) AddedFields() []string { +func (m *GroupMutation) AddedFields() []string { var fields []string - if m.addresponse_status != nil { - fields = append(fields, idempotencyrecord.FieldResponseStatus) + if m.addrate_multiplier != nil { + fields = append(fields, group.FieldRateMultiplier) + } + if m.adddaily_limit_usd != nil { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.addweekly_limit_usd != nil { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.addmonthly_limit_usd != nil { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.adddefault_validity_days != nil { + fields = append(fields, group.FieldDefaultValidityDays) + } + if m.addimage_price_1k != nil { + fields = append(fields, group.FieldImagePrice1k) + } + if m.addimage_price_2k != nil { + fields = append(fields, group.FieldImagePrice2k) + } + if m.addimage_price_4k != nil { + fields = append(fields, group.FieldImagePrice4k) + } + if m.addfallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.addsort_order != nil { + fields = append(fields, group.FieldSortOrder) } return fields } @@ -12032,10 +12542,30 @@ func (m *IdempotencyRecordMutation) AddedFields() []string { // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { +func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { switch name { - case idempotencyrecord.FieldResponseStatus: - return m.AddedResponseStatus() + case group.FieldRateMultiplier: + return m.AddedRateMultiplier() + case group.FieldDailyLimitUsd: + return m.AddedDailyLimitUsd() + case group.FieldWeeklyLimitUsd: + return m.AddedWeeklyLimitUsd() + case group.FieldMonthlyLimitUsd: + return m.AddedMonthlyLimitUsd() + case group.FieldDefaultValidityDays: + return m.AddedDefaultValidityDays() + case group.FieldImagePrice1k: + return m.AddedImagePrice1k() + case group.FieldImagePrice2k: + return m.AddedImagePrice2k() + case group.FieldImagePrice4k: + return m.AddedImagePrice4k() + case group.FieldFallbackGroupID: + return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() + case group.FieldSortOrder: + return m.AddedSortOrder() } return nil, false } @@ -12043,182 +12573,524 @@ func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { +func (m *GroupMutation) AddField(name string, value ent.Value) error { switch name { - case idempotencyrecord.FieldResponseStatus: - v, ok := value.(int) + case group.FieldRateMultiplier: + v, ok := value.(float64) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.AddResponseStatus(v) + m.AddRateMultiplier(v) return nil - } - return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) -} - -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *IdempotencyRecordMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { - fields = append(fields, idempotencyrecord.FieldResponseStatus) - } - if m.FieldCleared(idempotencyrecord.FieldResponseBody) { - fields = append(fields, idempotencyrecord.FieldResponseBody) - } - if m.FieldCleared(idempotencyrecord.FieldErrorReason) { - fields = append(fields, idempotencyrecord.FieldErrorReason) - } - if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { - fields = append(fields, idempotencyrecord.FieldLockedUntil) - } - return fields -} - -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok -} - -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *IdempotencyRecordMutation) ClearField(name string) error { - switch name { - case idempotencyrecord.FieldResponseStatus: - m.ClearResponseStatus() + case group.FieldDailyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDailyLimitUsd(v) return nil - case idempotencyrecord.FieldResponseBody: - m.ClearResponseBody() + case group.FieldWeeklyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddWeeklyLimitUsd(v) return nil - case idempotencyrecord.FieldErrorReason: - m.ClearErrorReason() + case group.FieldMonthlyLimitUsd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMonthlyLimitUsd(v) return nil - case idempotencyrecord.FieldLockedUntil: - m.ClearLockedUntil() + case group.FieldDefaultValidityDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDefaultValidityDays(v) + return nil + case group.FieldImagePrice1k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice1k(v) + return nil + case group.FieldImagePrice2k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice2k(v) + return nil + case group.FieldImagePrice4k: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddImagePrice4k(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupID(v) + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) return nil } - return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) + return fmt.Errorf("unknown Group numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *GroupMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(group.FieldDeletedAt) { + fields = append(fields, group.FieldDeletedAt) + } + if m.FieldCleared(group.FieldDescription) { + fields = append(fields, group.FieldDescription) + } + if m.FieldCleared(group.FieldDailyLimitUsd) { + fields = append(fields, group.FieldDailyLimitUsd) + } + if m.FieldCleared(group.FieldWeeklyLimitUsd) { + fields = append(fields, group.FieldWeeklyLimitUsd) + } + if m.FieldCleared(group.FieldMonthlyLimitUsd) { + fields = append(fields, group.FieldMonthlyLimitUsd) + } + if m.FieldCleared(group.FieldImagePrice1k) { + fields = append(fields, group.FieldImagePrice1k) + } + if m.FieldCleared(group.FieldImagePrice2k) { + fields = append(fields, group.FieldImagePrice2k) + } + if m.FieldCleared(group.FieldImagePrice4k) { + fields = append(fields, group.FieldImagePrice4k) + } + if m.FieldCleared(group.FieldFallbackGroupID) { + fields = append(fields, group.FieldFallbackGroupID) + } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } + if m.FieldCleared(group.FieldModelRouting) { + fields = append(fields, group.FieldModelRouting) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *GroupMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *GroupMutation) ClearField(name string) error { + switch name { + case group.FieldDeletedAt: + m.ClearDeletedAt() + return nil + case group.FieldDescription: + m.ClearDescription() + return nil + case group.FieldDailyLimitUsd: + m.ClearDailyLimitUsd() + return nil + case group.FieldWeeklyLimitUsd: + m.ClearWeeklyLimitUsd() + return nil + case group.FieldMonthlyLimitUsd: + m.ClearMonthlyLimitUsd() + return nil + case group.FieldImagePrice1k: + m.ClearImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ClearImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ClearImagePrice4k() + return nil + case group.FieldFallbackGroupID: + m.ClearFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ClearModelRouting() + return nil + } + return fmt.Errorf("unknown Group nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *IdempotencyRecordMutation) ResetField(name string) error { +func (m *GroupMutation) ResetField(name string) error { switch name { - case idempotencyrecord.FieldCreatedAt: + case group.FieldCreatedAt: m.ResetCreatedAt() return nil - case idempotencyrecord.FieldUpdatedAt: + case group.FieldUpdatedAt: m.ResetUpdatedAt() return nil - case idempotencyrecord.FieldScope: - m.ResetScope() + case group.FieldDeletedAt: + m.ResetDeletedAt() return nil - case idempotencyrecord.FieldIdempotencyKeyHash: - m.ResetIdempotencyKeyHash() + case group.FieldName: + m.ResetName() return nil - case idempotencyrecord.FieldRequestFingerprint: - m.ResetRequestFingerprint() + case group.FieldDescription: + m.ResetDescription() return nil - case idempotencyrecord.FieldStatus: + case group.FieldRateMultiplier: + m.ResetRateMultiplier() + return nil + case group.FieldIsExclusive: + m.ResetIsExclusive() + return nil + case group.FieldStatus: m.ResetStatus() return nil - case idempotencyrecord.FieldResponseStatus: - m.ResetResponseStatus() + case group.FieldPlatform: + m.ResetPlatform() return nil - case idempotencyrecord.FieldResponseBody: - m.ResetResponseBody() + case group.FieldSubscriptionType: + m.ResetSubscriptionType() return nil - case idempotencyrecord.FieldErrorReason: - m.ResetErrorReason() + case group.FieldDailyLimitUsd: + m.ResetDailyLimitUsd() return nil - case idempotencyrecord.FieldLockedUntil: - m.ResetLockedUntil() + case group.FieldWeeklyLimitUsd: + m.ResetWeeklyLimitUsd() return nil - case idempotencyrecord.FieldExpiresAt: - m.ResetExpiresAt() + case group.FieldMonthlyLimitUsd: + m.ResetMonthlyLimitUsd() + return nil + case group.FieldDefaultValidityDays: + m.ResetDefaultValidityDays() + return nil + case group.FieldImagePrice1k: + m.ResetImagePrice1k() + return nil + case group.FieldImagePrice2k: + m.ResetImagePrice2k() + return nil + case group.FieldImagePrice4k: + m.ResetImagePrice4k() + return nil + case group.FieldClaudeCodeOnly: + m.ResetClaudeCodeOnly() + return nil + case group.FieldFallbackGroupID: + m.ResetFallbackGroupID() + return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil + case group.FieldModelRouting: + m.ResetModelRouting() + return nil + case group.FieldModelRoutingEnabled: + m.ResetModelRoutingEnabled() + return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil + case group.FieldSortOrder: + m.ResetSortOrder() + return nil + case group.FieldAllowMessagesDispatch: + m.ResetAllowMessagesDispatch() + return nil + case group.FieldRequireOauthOnly: + m.ResetRequireOauthOnly() + return nil + case group.FieldRequirePrivacySet: + m.ResetRequirePrivacySet() + return nil + case group.FieldDefaultMappedModel: + m.ResetDefaultMappedModel() + return nil + case group.FieldMessagesDispatchModelConfig: + m.ResetMessagesDispatchModelConfig() return nil } - return fmt.Errorf("unknown IdempotencyRecord field %s", name) + return fmt.Errorf("unknown Group field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *IdempotencyRecordMutation) AddedEdges() []string { - edges := make([]string, 0, 0) - return edges -} - -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { - return nil -} +func (m *GroupMutation) AddedEdges() []string { + edges := make([]string, 0, 6) + if m.api_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.redeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.subscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.usage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.accounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.allowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *GroupMutation) AddedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.api_keys)) + for id := range m.api_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.redeem_codes)) + for id := range m.redeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.subscriptions)) + for id := range m.subscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.usage_logs)) + for id := range m.usage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.accounts)) + for id := range m.accounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.allowed_users)) + for id := range m.allowed_users { + ids = append(ids, id) + } + return ids + } + return nil +} // RemovedEdges returns all edge names that were removed in this mutation. -func (m *IdempotencyRecordMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *GroupMutation) RemovedEdges() []string { + edges := make([]string, 0, 6) + if m.removedapi_keys != nil { + edges = append(edges, group.EdgeAPIKeys) + } + if m.removedredeem_codes != nil { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.removedsubscriptions != nil { + edges = append(edges, group.EdgeSubscriptions) + } + if m.removedusage_logs != nil { + edges = append(edges, group.EdgeUsageLogs) + } + if m.removedaccounts != nil { + edges = append(edges, group.EdgeAccounts) + } + if m.removedallowed_users != nil { + edges = append(edges, group.EdgeAllowedUsers) + } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { +func (m *GroupMutation) RemovedIDs(name string) []ent.Value { + switch name { + case group.EdgeAPIKeys: + ids := make([]ent.Value, 0, len(m.removedapi_keys)) + for id := range m.removedapi_keys { + ids = append(ids, id) + } + return ids + case group.EdgeRedeemCodes: + ids := make([]ent.Value, 0, len(m.removedredeem_codes)) + for id := range m.removedredeem_codes { + ids = append(ids, id) + } + return ids + case group.EdgeSubscriptions: + ids := make([]ent.Value, 0, len(m.removedsubscriptions)) + for id := range m.removedsubscriptions { + ids = append(ids, id) + } + return ids + case group.EdgeUsageLogs: + ids := make([]ent.Value, 0, len(m.removedusage_logs)) + for id := range m.removedusage_logs { + ids = append(ids, id) + } + return ids + case group.EdgeAccounts: + ids := make([]ent.Value, 0, len(m.removedaccounts)) + for id := range m.removedaccounts { + ids = append(ids, id) + } + return ids + case group.EdgeAllowedUsers: + ids := make([]ent.Value, 0, len(m.removedallowed_users)) + for id := range m.removedallowed_users { + ids = append(ids, id) + } + return ids + } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *IdempotencyRecordMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *GroupMutation) ClearedEdges() []string { + edges := make([]string, 0, 6) + if m.clearedapi_keys { + edges = append(edges, group.EdgeAPIKeys) + } + if m.clearedredeem_codes { + edges = append(edges, group.EdgeRedeemCodes) + } + if m.clearedsubscriptions { + edges = append(edges, group.EdgeSubscriptions) + } + if m.clearedusage_logs { + edges = append(edges, group.EdgeUsageLogs) + } + if m.clearedaccounts { + edges = append(edges, group.EdgeAccounts) + } + if m.clearedallowed_users { + edges = append(edges, group.EdgeAllowedUsers) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { +func (m *GroupMutation) EdgeCleared(name string) bool { + switch name { + case group.EdgeAPIKeys: + return m.clearedapi_keys + case group.EdgeRedeemCodes: + return m.clearedredeem_codes + case group.EdgeSubscriptions: + return m.clearedsubscriptions + case group.EdgeUsageLogs: + return m.clearedusage_logs + case group.EdgeAccounts: + return m.clearedaccounts + case group.EdgeAllowedUsers: + return m.clearedallowed_users + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *IdempotencyRecordMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +func (m *GroupMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown Group unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *IdempotencyRecordMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +func (m *GroupMutation) ResetEdge(name string) error { + switch name { + case group.EdgeAPIKeys: + m.ResetAPIKeys() + return nil + case group.EdgeRedeemCodes: + m.ResetRedeemCodes() + return nil + case group.EdgeSubscriptions: + m.ResetSubscriptions() + return nil + case group.EdgeUsageLogs: + m.ResetUsageLogs() + return nil + case group.EdgeAccounts: + m.ResetAccounts() + return nil + case group.EdgeAllowedUsers: + m.ResetAllowedUsers() + return nil + } + return fmt.Errorf("unknown Group edge %s", name) } -// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. -type PaymentAuditLogMutation struct { +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { config - op Op - typ string - id *int64 - order_id *string - action *string - detail *string - operator *string - created_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentAuditLog, error) - predicates []predicate.PaymentAuditLog + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord } -var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) -// paymentauditlogOption allows management of the mutation configuration using functional options. -type paymentauditlogOption func(*PaymentAuditLogMutation) +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) -// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. -func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { - m := &PaymentAuditLogMutation{ +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ config: c, op: op, - typ: TypePaymentAuditLog, + typ: TypeIdempotencyRecord, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -12227,20 +13099,20 @@ func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) return m } -// withPaymentAuditLogID sets the ID field of the mutation. -func withPaymentAuditLogID(id int64) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { var ( err error once sync.Once - value *PaymentAuditLog + value *IdempotencyRecord ) - m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentAuditLog.Get(ctx, id) + value, err = m.Client().IdempotencyRecord.Get(ctx, id) } }) return value, err @@ -12249,10 +13121,10 @@ func withPaymentAuditLogID(id int64) paymentauditlogOption { } } -// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. -func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { - return func(m *PaymentAuditLogMutation) { - m.oldValue = func(context.Context) (*PaymentAuditLog, error) { +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. +func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { return node, nil } m.id = &node.ID @@ -12261,7 +13133,7 @@ func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentAuditLogMutation) Client() *Client { +func (m IdempotencyRecordMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -12269,7 +13141,7 @@ func (m PaymentAuditLogMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentAuditLogMutation) Tx() (*Tx, error) { +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -12280,7 +13152,7 @@ func (m PaymentAuditLogMutation) Tx() (*Tx, error) { // ID returns the ID value in the mutation. Note that the ID is only available // if it was provided to the builder or after it was returned from the database. -func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { if m.id == nil { return } @@ -12291,7 +13163,7 @@ func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { // That means, if the mutation is applied within a transaction with an isolation level such // as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated // or updated by the mutation. -func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { switch { case m.op.Is(OpUpdateOne | OpDeleteOne): id, exists := m.ID() @@ -12300,3381 +13172,6098 @@ func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { } fallthrough case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) default: return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } } -// SetOrderID sets the "order_id" field. -func (m *PaymentAuditLogMutation) SetOrderID(s string) { - m.order_id = &s +// SetCreatedAt sets the "created_at" field. +func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// OrderID returns the value of the "order_id" field in the mutation. -func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { - v := m.order_id +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderID is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderID requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOrderID: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.OrderID, nil + return oldValue.CreatedAt, nil } -// ResetOrderID resets all changes to the "order_id" field. -func (m *PaymentAuditLogMutation) ResetOrderID() { - m.order_id = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil } -// SetAction sets the "action" field. -func (m *PaymentAuditLogMutation) SetAction(s string) { - m.action = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// Action returns the value of the "action" field in the mutation. -func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { - v := m.action +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldAction returns the old "action" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAction is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAction requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAction: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.Action, nil + return oldValue.UpdatedAt, nil } -// ResetAction resets all changes to the "action" field. -func (m *PaymentAuditLogMutation) ResetAction() { - m.action = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetDetail sets the "detail" field. -func (m *PaymentAuditLogMutation) SetDetail(s string) { - m.detail = &s +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s } -// Detail returns the value of the "detail" field in the mutation. -func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { - v := m.detail +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope if v == nil { return } return *v, true } -// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDetail is only allowed on UpdateOne operations") + return v, errors.New("OldScope is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDetail requires an ID field in the mutation") + return v, errors.New("OldScope requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldDetail: %w", err) + return v, fmt.Errorf("querying old value for OldScope: %w", err) } - return oldValue.Detail, nil + return oldValue.Scope, nil } -// ResetDetail resets all changes to the "detail" field. -func (m *PaymentAuditLogMutation) ResetDetail() { - m.detail = nil +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil } -// SetOperator sets the "operator" field. -func (m *PaymentAuditLogMutation) SetOperator(s string) { - m.operator = &s +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s } -// Operator returns the value of the "operator" field in the mutation. -func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { - v := m.operator +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash if v == nil { return } return *v, true } -// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOperator is only allowed on UpdateOne operations") + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOperator requires an ID field in the mutation") + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOperator: %w", err) + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) } - return oldValue.Operator, nil + return oldValue.IdempotencyKeyHash, nil } -// ResetOperator resets all changes to the "operator" field. -func (m *PaymentAuditLogMutation) ResetOperator() { - m.operator = nil +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. -// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.RequestFingerprint, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentAuditLogMutation) ResetCreatedAt() { - m.created_at = nil +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil } -// Where appends a list predicates to the PaymentAuditLogMutation builder. -func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { - m.predicates = append(m.predicates, ps...) +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s } -// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentAuditLog, len(ps)) - for i := range ps { - p[i] = ps[i] +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return } - m.Where(p...) + return *v, true } -// Op returns the operation name. -func (m *PaymentAuditLogMutation) Op() Op { - return m.op +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil } -// SetOp allows setting the mutation operation. -func (m *PaymentAuditLogMutation) SetOp(op Op) { - m.op = op +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil } -// Type returns the node type of this mutation (PaymentAuditLog). -func (m *PaymentAuditLogMutation) Type() string { - return m.typ +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *PaymentAuditLogMutation) Fields() []string { - fields := make([]string, 0, 5) - if m.order_id != nil { - fields = append(fields, paymentauditlog.FieldOrderID) - } - if m.action != nil { - fields = append(fields, paymentauditlog.FieldAction) - } - if m.detail != nil { - fields = append(fields, paymentauditlog.FieldDetail) - } - if m.operator != nil { - fields = append(fields, paymentauditlog.FieldOperator) - } - if m.created_at != nil { - fields = append(fields, paymentauditlog.FieldCreatedAt) +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { - switch name { - case paymentauditlog.FieldOrderID: - return m.OrderID() - case paymentauditlog.FieldAction: - return m.Action() - case paymentauditlog.FieldDetail: - return m.Detail() - case paymentauditlog.FieldOperator: - return m.Operator() - case paymentauditlog.FieldCreatedAt: - return m.CreatedAt() +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") } - return nil, false -} - -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case paymentauditlog.FieldOrderID: - return m.OldOrderID(ctx) - case paymentauditlog.FieldAction: - return m.OldAction(ctx) - case paymentauditlog.FieldDetail: - return m.OldDetail(ctx) - case paymentauditlog.FieldOperator: - return m.OldOperator(ctx) - case paymentauditlog.FieldCreatedAt: - return m.OldCreatedAt(ctx) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") } - return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { - switch name { - case paymentauditlog.FieldOrderID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOrderID(v) - return nil - case paymentauditlog.FieldAction: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAction(v) - return nil - case paymentauditlog.FieldDetail: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDetail(v) - return nil - case paymentauditlog.FieldOperator: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOperator(v) - return nil - case paymentauditlog.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreatedAt(v) - return nil +// AddResponseStatus adds i to the "response_status" field. +func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *PaymentAuditLogMutation) AddedFields() []string { - return nil +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { - return nil, false +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { - switch name { - } - return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *PaymentAuditLogMutation) ClearedFields() []string { - return nil +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *PaymentAuditLogMutation) ResetField(name string) error { - switch name { - case paymentauditlog.FieldOrderID: - m.ResetOrderID() - return nil - case paymentauditlog.FieldAction: - m.ResetAction() - return nil - case paymentauditlog.FieldDetail: - m.ResetDetail() - return nil - case paymentauditlog.FieldOperator: - m.ResetOperator() - return nil - case paymentauditlog.FieldCreatedAt: - m.ResetCreatedAt() - return nil +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown PaymentAuditLog field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentAuditLogMutation) AddedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { - return nil +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentAuditLogMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) } -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { - return nil +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentAuditLogMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) - return edges +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true } -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { - return false +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *PaymentAuditLogMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *PaymentAuditLogMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentAuditLog edge %s", name) -} - -// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. -type PaymentOrderMutation struct { - config - op Op - typ string - id *int64 - user_email *string - user_name *string - user_notes *string - amount *float64 - addamount *float64 - pay_amount *float64 - addpay_amount *float64 - fee_rate *float64 - addfee_rate *float64 - recharge_code *string - out_trade_no *string - payment_type *string - payment_trade_no *string - pay_url *string - qr_code *string - qr_code_img *string - order_type *string - plan_id *int64 - addplan_id *int64 - subscription_group_id *int64 - addsubscription_group_id *int64 - subscription_days *int - addsubscription_days *int - provider_instance_id *string - status *string - refund_amount *float64 - addrefund_amount *float64 - refund_reason *string - refund_at *time.Time - force_refund *bool - refund_requested_at *time.Time - refund_request_reason *string - refund_requested_by *string - expires_at *time.Time - paid_at *time.Time - completed_at *time.Time - failed_at *time.Time - failed_reason *string - client_ip *string - src_host *string - src_url *string - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - user *int64 - cleareduser bool - done bool - oldValue func(context.Context) (*PaymentOrder, error) - predicates []predicate.PaymentOrder -} - -var _ ent.Mutation = (*PaymentOrderMutation)(nil) - -// paymentorderOption allows management of the mutation configuration using functional options. -type paymentorderOption func(*PaymentOrderMutation) - -// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. -func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { - m := &PaymentOrderMutation{ - config: c, - op: op, - typ: TypePaymentOrder, - clearedFields: make(map[string]struct{}), - } - for _, opt := range opts { - opt(m) - } - return m -} - -// withPaymentOrderID sets the ID field of the mutation. -func withPaymentOrderID(id int64) paymentorderOption { - return func(m *PaymentOrderMutation) { - var ( - err error - once sync.Once - value *PaymentOrder - ) - m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().PaymentOrder.Get(ctx, id) - } - }) - return value, err - } - m.id = &id - } -} - -// withPaymentOrder sets the old PaymentOrder of the mutation. -func withPaymentOrder(node *PaymentOrder) paymentorderOption { - return func(m *PaymentOrderMutation) { - m.oldValue = func(context.Context) (*PaymentOrder, error) { - return node, nil - } - m.id = &node.ID - } -} - -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentOrderMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client -} - -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m PaymentOrderMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") - } - tx := &Tx{config: m.config} - tx.init() - return tx, nil -} - -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *PaymentOrderMutation) ID() (id int64, exists bool) { - if m.id == nil { - return - } - return *m.id, true +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) } -// SetUserID sets the "user_id" field. -func (m *PaymentOrderMutation) SetUserID(i int64) { - m.user = &i +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t } -// UserID returns the value of the "user_id" field in the mutation. -func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { - v := m.user +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until if v == nil { return } return *v, true } -// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { +func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserID is only allowed on UpdateOne operations") + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserID requires an ID field in the mutation") + return v, errors.New("OldLockedUntil requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserID: %w", err) + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) } - return oldValue.UserID, nil + return oldValue.LockedUntil, nil } -// ResetUserID resets all changes to the "user_id" field. -func (m *PaymentOrderMutation) ResetUserID() { - m.user = nil +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} } -// SetUserEmail sets the "user_email" field. -func (m *PaymentOrderMutation) SetUserEmail(s string) { - m.user_email = &s +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok } -// UserEmail returns the value of the "user_email" field in the mutation. -func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { - v := m.user_email +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at if v == nil { return } return *v, true } -// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { +func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserEmail requires an ID field in the mutation") + return v, errors.New("OldExpiresAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) } - return oldValue.UserEmail, nil + return oldValue.ExpiresAt, nil } -// ResetUserEmail resets all changes to the "user_email" field. -func (m *PaymentOrderMutation) ResetUserEmail() { - m.user_email = nil +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil } -// SetUserName sets the "user_name" field. -func (m *PaymentOrderMutation) SetUserName(s string) { - m.user_name = &s +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) } -// UserName returns the value of the "user_name" field in the mutation. -func (m *PaymentOrderMutation) UserName() (r string, exists bool) { - v := m.user_name - if v == nil { - return - } - return *v, true -} - -// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserName is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserName requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUserName: %w", err) +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] } - return oldValue.UserName, nil + m.Where(p...) } -// ResetUserName resets all changes to the "user_name" field. -func (m *PaymentOrderMutation) ResetUserName() { - m.user_name = nil +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op } -// SetUserNotes sets the "user_notes" field. -func (m *PaymentOrderMutation) SetUserNotes(s string) { - m.user_notes = &s +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op } -// UserNotes returns the value of the "user_notes" field in the mutation. -func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { - v := m.user_notes - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ } -// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUserNotes requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) } - return oldValue.UserNotes, nil -} - -// ClearUserNotes clears the value of the "user_notes" field. -func (m *PaymentOrderMutation) ClearUserNotes() { - m.user_notes = nil - m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} -} - -// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. -func (m *PaymentOrderMutation) UserNotesCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldUserNotes] - return ok -} - -// ResetUserNotes resets all changes to the "user_notes" field. -func (m *PaymentOrderMutation) ResetUserNotes() { - m.user_notes = nil - delete(m.clearedFields, paymentorder.FieldUserNotes) -} - -// SetAmount sets the "amount" field. -func (m *PaymentOrderMutation) SetAmount(f float64) { - m.amount = &f - m.addamount = nil -} - -// Amount returns the value of the "amount" field in the mutation. -func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { - v := m.amount - if v == nil { - return + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) } - return *v, true -} - -// OldAmount returns the old "amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAmount is only allowed on UpdateOne operations") + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAmount requires an ID field in the mutation") + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldAmount: %w", err) + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) } - return oldValue.Amount, nil -} - -// AddAmount adds f to the "amount" field. -func (m *PaymentOrderMutation) AddAmount(f float64) { - if m.addamount != nil { - *m.addamount += f - } else { - m.addamount = &f + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) } -} - -// AddedAmount returns the value that was added to the "amount" field in this mutation. -func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { - v := m.addamount - if v == nil { - return + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return *v, true -} - -// ResetAmount resets all changes to the "amount" field. -func (m *PaymentOrderMutation) ResetAmount() { - m.amount = nil - m.addamount = nil -} - -// SetPayAmount sets the "pay_amount" field. -func (m *PaymentOrderMutation) SetPayAmount(f float64) { - m.pay_amount = &f - m.addpay_amount = nil -} - -// PayAmount returns the value of the "pay_amount" field in the mutation. -func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { - v := m.pay_amount - if v == nil { - return + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) } - return *v, true -} - -// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayAmount requires an ID field in the mutation") + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) } - return oldValue.PayAmount, nil + return fields } -// AddPayAmount adds f to the "pay_amount" field. -func (m *PaymentOrderMutation) AddPayAmount(f float64) { - if m.addpay_amount != nil { - *m.addpay_amount += f - } else { - m.addpay_amount = &f +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() } + return nil, false } -// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { - v := m.addpay_amount - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) } - return *v, true -} - -// ResetPayAmount resets all changes to the "pay_amount" field. -func (m *PaymentOrderMutation) ResetPayAmount() { - m.pay_amount = nil - m.addpay_amount = nil + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) } -// SetFeeRate sets the "fee_rate" field. -func (m *PaymentOrderMutation) SetFeeRate(f float64) { - m.fee_rate = &f - m.addfee_rate = nil -} - -// FeeRate returns the value of the "fee_rate" field in the mutation. -func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { - v := m.fee_rate - if v == nil { - return +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil } - return *v, true + return fmt.Errorf("unknown IdempotencyRecord field %s", name) } -// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFeeRate requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return oldValue.FeeRate, nil + return fields } -// AddFeeRate adds f to the "fee_rate" field. -func (m *PaymentOrderMutation) AddFeeRate(f float64) { - if m.addfee_rate != nil { - *m.addfee_rate += f - } else { - m.addfee_rate = &f +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() } + return nil, false } -// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. -func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { - v := m.addfee_rate - if v == nil { - return +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil } - return *v, true -} - -// ResetFeeRate resets all changes to the "fee_rate" field. -func (m *PaymentOrderMutation) ResetFeeRate() { - m.fee_rate = nil - m.addfee_rate = nil -} - -// SetRechargeCode sets the "recharge_code" field. -func (m *PaymentOrderMutation) SetRechargeCode(s string) { - m.recharge_code = &s + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) } -// RechargeCode returns the value of the "recharge_code" field in the mutation. -func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { - v := m.recharge_code - if v == nil { - return +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) } - return *v, true -} - -// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRechargeCode requires an ID field in the mutation") + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) } - return oldValue.RechargeCode, nil -} - -// ResetRechargeCode resets all changes to the "recharge_code" field. -func (m *PaymentOrderMutation) ResetRechargeCode() { - m.recharge_code = nil -} - -// SetOutTradeNo sets the "out_trade_no" field. -func (m *PaymentOrderMutation) SetOutTradeNo(s string) { - m.out_trade_no = &s + return fields } -// OutTradeNo returns the value of the "out_trade_no" field in the mutation. -func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { - v := m.out_trade_no - if v == nil { - return - } - return *v, true +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOutTradeNo requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil } - return oldValue.OutTradeNo, nil -} - -// ResetOutTradeNo resets all changes to the "out_trade_no" field. -func (m *PaymentOrderMutation) ResetOutTradeNo() { - m.out_trade_no = nil -} - -// SetPaymentType sets the "payment_type" field. -func (m *PaymentOrderMutation) SetPaymentType(s string) { - m.payment_type = &s + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) } -// PaymentType returns the value of the "payment_type" field in the mutation. -func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { - v := m.payment_type +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + +// IdentityAdoptionDecisionMutation represents an operation that mutates the IdentityAdoptionDecision nodes in the graph. +type IdentityAdoptionDecisionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + adopt_display_name *bool + adopt_avatar *bool + decided_at *time.Time + clearedFields map[string]struct{} + pending_auth_session *int64 + clearedpending_auth_session bool + identity *int64 + clearedidentity bool + done bool + oldValue func(context.Context) (*IdentityAdoptionDecision, error) + predicates []predicate.IdentityAdoptionDecision +} + +var _ ent.Mutation = (*IdentityAdoptionDecisionMutation)(nil) + +// identityadoptiondecisionOption allows management of the mutation configuration using functional options. +type identityadoptiondecisionOption func(*IdentityAdoptionDecisionMutation) + +// newIdentityAdoptionDecisionMutation creates new mutation for the IdentityAdoptionDecision entity. +func newIdentityAdoptionDecisionMutation(c config, op Op, opts ...identityadoptiondecisionOption) *IdentityAdoptionDecisionMutation { + m := &IdentityAdoptionDecisionMutation{ + config: c, + op: op, + typ: TypeIdentityAdoptionDecision, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdentityAdoptionDecisionID sets the ID field of the mutation. +func withIdentityAdoptionDecisionID(id int64) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + var ( + err error + once sync.Once + value *IdentityAdoptionDecision + ) + m.oldValue = func(ctx context.Context) (*IdentityAdoptionDecision, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdentityAdoptionDecision.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdentityAdoptionDecision sets the old IdentityAdoptionDecision of the mutation. +func withIdentityAdoptionDecision(node *IdentityAdoptionDecision) identityadoptiondecisionOption { + return func(m *IdentityAdoptionDecisionMutation) { + m.oldValue = func(context.Context) (*IdentityAdoptionDecision, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdentityAdoptionDecisionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdentityAdoptionDecisionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdentityAdoptionDecisionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdentityAdoptionDecisionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdentityAdoptionDecision.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentType requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.PaymentType, nil + return oldValue.CreatedAt, nil } -// ResetPaymentType resets all changes to the "payment_type" field. -func (m *PaymentOrderMutation) ResetPaymentType() { - m.payment_type = nil +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetCreatedAt() { + m.created_at = nil } -// SetPaymentTradeNo sets the "payment_trade_no" field. -func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { - m.payment_trade_no = &s +// SetUpdatedAt sets the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t } -// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. -func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { - v := m.payment_trade_no +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at if v == nil { return } return *v, true } -// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldUpdatedAt returns the old "updated_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) } - return oldValue.PaymentTradeNo, nil + return oldValue.UpdatedAt, nil } -// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. -func (m *PaymentOrderMutation) ResetPaymentTradeNo() { - m.payment_trade_no = nil +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetUpdatedAt() { + m.updated_at = nil } -// SetPayURL sets the "pay_url" field. -func (m *PaymentOrderMutation) SetPayURL(s string) { - m.pay_url = &s +// SetPendingAuthSessionID sets the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) SetPendingAuthSessionID(i int64) { + m.pending_auth_session = &i } -// PayURL returns the value of the "pay_url" field in the mutation. -func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { - v := m.pay_url +// PendingAuthSessionID returns the value of the "pending_auth_session_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionID() (r int64, exists bool) { + v := m.pending_auth_session if v == nil { return } return *v, true } -// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldPendingAuthSessionID returns the old "pending_auth_session_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldPendingAuthSessionID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPayURL is only allowed on UpdateOne operations") + return v, errors.New("OldPendingAuthSessionID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPayURL requires an ID field in the mutation") + return v, errors.New("OldPendingAuthSessionID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPayURL: %w", err) + return v, fmt.Errorf("querying old value for OldPendingAuthSessionID: %w", err) } - return oldValue.PayURL, nil + return oldValue.PendingAuthSessionID, nil } -// ClearPayURL clears the value of the "pay_url" field. -func (m *PaymentOrderMutation) ClearPayURL() { - m.pay_url = nil - m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +// ResetPendingAuthSessionID resets all changes to the "pending_auth_session_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSessionID() { + m.pending_auth_session = nil } -// PayURLCleared returns if the "pay_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) PayURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPayURL] - return ok +// SetIdentityID sets the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) SetIdentityID(i int64) { + m.identity = &i } -// ResetPayURL resets all changes to the "pay_url" field. -func (m *PaymentOrderMutation) ResetPayURL() { - m.pay_url = nil - delete(m.clearedFields, paymentorder.FieldPayURL) +// IdentityID returns the value of the "identity_id" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityID() (r int64, exists bool) { + v := m.identity + if v == nil { + return + } + return *v, true } -// SetQrCode sets the "qr_code" field. -func (m *PaymentOrderMutation) SetQrCode(s string) { - m.qr_code = &s -} - -// QrCode returns the value of the "qr_code" field in the mutation. -func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { - v := m.qr_code - if v == nil { - return - } - return *v, true -} - -// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldIdentityID returns the old "identity_id" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldIdentityID(ctx context.Context) (v *int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCode is only allowed on UpdateOne operations") + return v, errors.New("OldIdentityID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCode requires an ID field in the mutation") + return v, errors.New("OldIdentityID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldQrCode: %w", err) + return v, fmt.Errorf("querying old value for OldIdentityID: %w", err) } - return oldValue.QrCode, nil + return oldValue.IdentityID, nil } -// ClearQrCode clears the value of the "qr_code" field. -func (m *PaymentOrderMutation) ClearQrCode() { - m.qr_code = nil - m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +// ClearIdentityID clears the value of the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ClearIdentityID() { + m.identity = nil + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} } -// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. -func (m *PaymentOrderMutation) QrCodeCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCode] +// IdentityIDCleared returns if the "identity_id" field was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) IdentityIDCleared() bool { + _, ok := m.clearedFields[identityadoptiondecision.FieldIdentityID] return ok } -// ResetQrCode resets all changes to the "qr_code" field. -func (m *PaymentOrderMutation) ResetQrCode() { - m.qr_code = nil - delete(m.clearedFields, paymentorder.FieldQrCode) +// ResetIdentityID resets all changes to the "identity_id" field. +func (m *IdentityAdoptionDecisionMutation) ResetIdentityID() { + m.identity = nil + delete(m.clearedFields, identityadoptiondecision.FieldIdentityID) } -// SetQrCodeImg sets the "qr_code_img" field. -func (m *PaymentOrderMutation) SetQrCodeImg(s string) { - m.qr_code_img = &s +// SetAdoptDisplayName sets the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptDisplayName(b bool) { + m.adopt_display_name = &b } -// QrCodeImg returns the value of the "qr_code_img" field in the mutation. -func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { - v := m.qr_code_img +// AdoptDisplayName returns the value of the "adopt_display_name" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptDisplayName() (r bool, exists bool) { + v := m.adopt_display_name if v == nil { return } return *v, true } -// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAdoptDisplayName returns the old "adopt_display_name" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldAdoptDisplayName(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") + return v, errors.New("OldAdoptDisplayName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldQrCodeImg requires an ID field in the mutation") + return v, errors.New("OldAdoptDisplayName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + return v, fmt.Errorf("querying old value for OldAdoptDisplayName: %w", err) } - return oldValue.QrCodeImg, nil -} - -// ClearQrCodeImg clears the value of the "qr_code_img" field. -func (m *PaymentOrderMutation) ClearQrCodeImg() { - m.qr_code_img = nil - m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} -} - -// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. -func (m *PaymentOrderMutation) QrCodeImgCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] - return ok + return oldValue.AdoptDisplayName, nil } -// ResetQrCodeImg resets all changes to the "qr_code_img" field. -func (m *PaymentOrderMutation) ResetQrCodeImg() { - m.qr_code_img = nil - delete(m.clearedFields, paymentorder.FieldQrCodeImg) +// ResetAdoptDisplayName resets all changes to the "adopt_display_name" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptDisplayName() { + m.adopt_display_name = nil } -// SetOrderType sets the "order_type" field. -func (m *PaymentOrderMutation) SetOrderType(s string) { - m.order_type = &s +// SetAdoptAvatar sets the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) SetAdoptAvatar(b bool) { + m.adopt_avatar = &b } -// OrderType returns the value of the "order_type" field in the mutation. -func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { - v := m.order_type +// AdoptAvatar returns the value of the "adopt_avatar" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) AdoptAvatar() (r bool, exists bool) { + v := m.adopt_avatar if v == nil { return } return *v, true } -// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAdoptAvatar returns the old "adopt_avatar" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { +func (m *IdentityAdoptionDecisionMutation) OldAdoptAvatar(ctx context.Context) (v bool, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldOrderType is only allowed on UpdateOne operations") + return v, errors.New("OldAdoptAvatar is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldOrderType requires an ID field in the mutation") + return v, errors.New("OldAdoptAvatar requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldOrderType: %w", err) + return v, fmt.Errorf("querying old value for OldAdoptAvatar: %w", err) } - return oldValue.OrderType, nil + return oldValue.AdoptAvatar, nil } -// ResetOrderType resets all changes to the "order_type" field. -func (m *PaymentOrderMutation) ResetOrderType() { - m.order_type = nil +// ResetAdoptAvatar resets all changes to the "adopt_avatar" field. +func (m *IdentityAdoptionDecisionMutation) ResetAdoptAvatar() { + m.adopt_avatar = nil } -// SetPlanID sets the "plan_id" field. -func (m *PaymentOrderMutation) SetPlanID(i int64) { - m.plan_id = &i - m.addplan_id = nil +// SetDecidedAt sets the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) SetDecidedAt(t time.Time) { + m.decided_at = &t } -// PlanID returns the value of the "plan_id" field in the mutation. -func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { - v := m.plan_id +// DecidedAt returns the value of the "decided_at" field in the mutation. +func (m *IdentityAdoptionDecisionMutation) DecidedAt() (r time.Time, exists bool) { + v := m.decided_at if v == nil { return } return *v, true } -// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDecidedAt returns the old "decided_at" field's value of the IdentityAdoptionDecision entity. +// If the IdentityAdoptionDecision object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { +func (m *IdentityAdoptionDecisionMutation) OldDecidedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPlanID is only allowed on UpdateOne operations") + return v, errors.New("OldDecidedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPlanID requires an ID field in the mutation") + return v, errors.New("OldDecidedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPlanID: %w", err) + return v, fmt.Errorf("querying old value for OldDecidedAt: %w", err) } - return oldValue.PlanID, nil + return oldValue.DecidedAt, nil } -// AddPlanID adds i to the "plan_id" field. -func (m *PaymentOrderMutation) AddPlanID(i int64) { - if m.addplan_id != nil { - *m.addplan_id += i - } else { - m.addplan_id = &i - } +// ResetDecidedAt resets all changes to the "decided_at" field. +func (m *IdentityAdoptionDecisionMutation) ResetDecidedAt() { + m.decided_at = nil } -// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. -func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { - v := m.addplan_id - if v == nil { - return - } - return *v, true +// ClearPendingAuthSession clears the "pending_auth_session" edge to the PendingAuthSession entity. +func (m *IdentityAdoptionDecisionMutation) ClearPendingAuthSession() { + m.clearedpending_auth_session = true + m.clearedFields[identityadoptiondecision.FieldPendingAuthSessionID] = struct{}{} } -// ClearPlanID clears the value of the "plan_id" field. -func (m *PaymentOrderMutation) ClearPlanID() { - m.plan_id = nil - m.addplan_id = nil - m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +// PendingAuthSessionCleared reports if the "pending_auth_session" edge to the PendingAuthSession entity was cleared. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionCleared() bool { + return m.clearedpending_auth_session } -// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) PlanIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPlanID] - return ok +// PendingAuthSessionIDs returns the "pending_auth_session" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PendingAuthSessionID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) PendingAuthSessionIDs() (ids []int64) { + if id := m.pending_auth_session; id != nil { + ids = append(ids, *id) + } + return } -// ResetPlanID resets all changes to the "plan_id" field. -func (m *PaymentOrderMutation) ResetPlanID() { - m.plan_id = nil - m.addplan_id = nil - delete(m.clearedFields, paymentorder.FieldPlanID) +// ResetPendingAuthSession resets all changes to the "pending_auth_session" edge. +func (m *IdentityAdoptionDecisionMutation) ResetPendingAuthSession() { + m.pending_auth_session = nil + m.clearedpending_auth_session = false } -// SetSubscriptionGroupID sets the "subscription_group_id" field. -func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { - m.subscription_group_id = &i - m.addsubscription_group_id = nil +// ClearIdentity clears the "identity" edge to the AuthIdentity entity. +func (m *IdentityAdoptionDecisionMutation) ClearIdentity() { + m.clearedidentity = true + m.clearedFields[identityadoptiondecision.FieldIdentityID] = struct{}{} } -// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. -func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { - v := m.subscription_group_id - if v == nil { - return - } - return *v, true +// IdentityCleared reports if the "identity" edge to the AuthIdentity entity was cleared. +func (m *IdentityAdoptionDecisionMutation) IdentityCleared() bool { + return m.IdentityIDCleared() || m.clearedidentity } -// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) +// IdentityIDs returns the "identity" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// IdentityID instead. It exists only for internal usage by the builders. +func (m *IdentityAdoptionDecisionMutation) IdentityIDs() (ids []int64) { + if id := m.identity; id != nil { + ids = append(ids, *id) } - return oldValue.SubscriptionGroupID, nil + return } -// AddSubscriptionGroupID adds i to the "subscription_group_id" field. -func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { - if m.addsubscription_group_id != nil { - *m.addsubscription_group_id += i - } else { - m.addsubscription_group_id = &i - } +// ResetIdentity resets all changes to the "identity" edge. +func (m *IdentityAdoptionDecisionMutation) ResetIdentity() { + m.identity = nil + m.clearedidentity = false } -// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. -func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { - v := m.addsubscription_group_id - if v == nil { - return - } - return *v, true +// Where appends a list predicates to the IdentityAdoptionDecisionMutation builder. +func (m *IdentityAdoptionDecisionMutation) Where(ps ...predicate.IdentityAdoptionDecision) { + m.predicates = append(m.predicates, ps...) } -// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. -func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} +// WhereP appends storage-level predicates to the IdentityAdoptionDecisionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdentityAdoptionDecisionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdentityAdoptionDecision, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] - return ok +// Op returns the operation name. +func (m *IdentityAdoptionDecisionMutation) Op() Op { + return m.op } -// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. -func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { - m.subscription_group_id = nil - m.addsubscription_group_id = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +// SetOp allows setting the mutation operation. +func (m *IdentityAdoptionDecisionMutation) SetOp(op Op) { + m.op = op } -// SetSubscriptionDays sets the "subscription_days" field. -func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { - m.subscription_days = &i - m.addsubscription_days = nil +// Type returns the node type of this mutation (IdentityAdoptionDecision). +func (m *IdentityAdoptionDecisionMutation) Type() string { + return m.typ } -// SubscriptionDays returns the value of the "subscription_days" field in the mutation. -func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { - v := m.subscription_days - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *IdentityAdoptionDecisionMutation) Fields() []string { + fields := make([]string, 0, 7) + if m.created_at != nil { + fields = append(fields, identityadoptiondecision.FieldCreatedAt) } - return *v, true -} - -// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + if m.updated_at != nil { + fields = append(fields, identityadoptiondecision.FieldUpdatedAt) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + if m.pending_auth_session != nil { + fields = append(fields, identityadoptiondecision.FieldPendingAuthSessionID) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) + if m.identity != nil { + fields = append(fields, identityadoptiondecision.FieldIdentityID) } - return oldValue.SubscriptionDays, nil -} - -// AddSubscriptionDays adds i to the "subscription_days" field. -func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { - if m.addsubscription_days != nil { - *m.addsubscription_days += i - } else { - m.addsubscription_days = &i + if m.adopt_display_name != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptDisplayName) + } + if m.adopt_avatar != nil { + fields = append(fields, identityadoptiondecision.FieldAdoptAvatar) } + if m.decided_at != nil { + fields = append(fields, identityadoptiondecision.FieldDecidedAt) + } + return fields } -// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. -func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { - v := m.addsubscription_days - if v == nil { - return +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *IdentityAdoptionDecisionMutation) Field(name string) (ent.Value, bool) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.CreatedAt() + case identityadoptiondecision.FieldUpdatedAt: + return m.UpdatedAt() + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.PendingAuthSessionID() + case identityadoptiondecision.FieldIdentityID: + return m.IdentityID() + case identityadoptiondecision.FieldAdoptDisplayName: + return m.AdoptDisplayName() + case identityadoptiondecision.FieldAdoptAvatar: + return m.AdoptAvatar() + case identityadoptiondecision.FieldDecidedAt: + return m.DecidedAt() } - return *v, true + return nil, false } -// ClearSubscriptionDays clears the value of the "subscription_days" field. -func (m *PaymentOrderMutation) ClearSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *IdentityAdoptionDecisionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case identityadoptiondecision.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case identityadoptiondecision.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case identityadoptiondecision.FieldPendingAuthSessionID: + return m.OldPendingAuthSessionID(ctx) + case identityadoptiondecision.FieldIdentityID: + return m.OldIdentityID(ctx) + case identityadoptiondecision.FieldAdoptDisplayName: + return m.OldAdoptDisplayName(ctx) + case identityadoptiondecision.FieldAdoptAvatar: + return m.OldAdoptAvatar(ctx) + case identityadoptiondecision.FieldDecidedAt: + return m.OldDecidedAt(ctx) + } + return nil, fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. -func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] - return ok -} - -// ResetSubscriptionDays resets all changes to the "subscription_days" field. -func (m *PaymentOrderMutation) ResetSubscriptionDays() { - m.subscription_days = nil - m.addsubscription_days = nil - delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) SetField(name string, value ent.Value) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case identityadoptiondecision.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPendingAuthSessionID(v) + return nil + case identityadoptiondecision.FieldIdentityID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdentityID(v) + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptDisplayName(v) + return nil + case identityadoptiondecision.FieldAdoptAvatar: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAdoptAvatar(v) + return nil + case identityadoptiondecision.FieldDecidedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDecidedAt(v) + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// SetProviderInstanceID sets the "provider_instance_id" field. -func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { - m.provider_instance_id = &s +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedFields() []string { + var fields []string + return fields } -// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. -func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { - v := m.provider_instance_id - if v == nil { - return +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) AddedField(name string) (ent.Value, bool) { + switch name { } - return *v, true + return nil, false } -// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdentityAdoptionDecisionMutation) AddField(name string, value ent.Value) error { + switch name { } - return oldValue.ProviderInstanceID, nil + return fmt.Errorf("unknown IdentityAdoptionDecision numeric field %s", name) } -// ClearProviderInstanceID clears the value of the "provider_instance_id" field. -func (m *PaymentOrderMutation) ClearProviderInstanceID() { - m.provider_instance_id = nil - m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(identityadoptiondecision.FieldIdentityID) { + fields = append(fields, identityadoptiondecision.FieldIdentityID) + } + return fields } -// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. -func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] return ok } -// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. -func (m *PaymentOrderMutation) ResetProviderInstanceID() { - m.provider_instance_id = nil - delete(m.clearedFields, paymentorder.FieldProviderInstanceID) -} - -// SetStatus sets the "status" field. -func (m *PaymentOrderMutation) SetStatus(s string) { - m.status = &s +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearField(name string) error { + switch name { + case identityadoptiondecision.FieldIdentityID: + m.ClearIdentityID() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision nullable field %s", name) } -// Status returns the value of the "status" field in the mutation. -func (m *PaymentOrderMutation) Status() (r string, exists bool) { - v := m.status - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetField(name string) error { + switch name { + case identityadoptiondecision.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case identityadoptiondecision.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case identityadoptiondecision.FieldPendingAuthSessionID: + m.ResetPendingAuthSessionID() + return nil + case identityadoptiondecision.FieldIdentityID: + m.ResetIdentityID() + return nil + case identityadoptiondecision.FieldAdoptDisplayName: + m.ResetAdoptDisplayName() + return nil + case identityadoptiondecision.FieldAdoptAvatar: + m.ResetAdoptAvatar() + return nil + case identityadoptiondecision.FieldDecidedAt: + m.ResetDecidedAt() + return nil } - return *v, true + return fmt.Errorf("unknown IdentityAdoptionDecision field %s", name) } -// OldStatus returns the old "status" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldStatus is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldStatus requires an ID field in the mutation") +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.pending_auth_session != nil { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldStatus: %w", err) + if m.identity != nil { + edges = append(edges, identityadoptiondecision.EdgeIdentity) } - return oldValue.Status, nil + return edges } -// ResetStatus resets all changes to the "status" field. -func (m *PaymentOrderMutation) ResetStatus() { - m.status = nil +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdentityAdoptionDecisionMutation) AddedIDs(name string) []ent.Value { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + if id := m.pending_auth_session; id != nil { + return []ent.Value{*id} + } + case identityadoptiondecision.EdgeIdentity: + if id := m.identity; id != nil { + return []ent.Value{*id} + } + } + return nil } -// SetRefundAmount sets the "refund_amount" field. -func (m *PaymentOrderMutation) SetRefundAmount(f float64) { - m.refund_amount = &f - m.addrefund_amount = nil +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + return edges } -// RefundAmount returns the value of the "refund_amount" field in the mutation. -func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { - v := m.refund_amount - if v == nil { - return - } - return *v, true +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdentityAdoptionDecisionMutation) RemovedIDs(name string) []ent.Value { + return nil } -// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAmount requires an ID field in the mutation") +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedpending_auth_session { + edges = append(edges, identityadoptiondecision.EdgePendingAuthSession) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + if m.clearedidentity { + edges = append(edges, identityadoptiondecision.EdgeIdentity) } - return oldValue.RefundAmount, nil + return edges } -// AddRefundAmount adds f to the "refund_amount" field. -func (m *PaymentOrderMutation) AddRefundAmount(f float64) { - if m.addrefund_amount != nil { - *m.addrefund_amount += f - } else { - m.addrefund_amount = &f +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdentityAdoptionDecisionMutation) EdgeCleared(name string) bool { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + return m.clearedpending_auth_session + case identityadoptiondecision.EdgeIdentity: + return m.clearedidentity } + return false } -// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. -func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { - v := m.addrefund_amount - if v == nil { - return +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ClearEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ClearPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ClearIdentity() + return nil } - return *v, true + return fmt.Errorf("unknown IdentityAdoptionDecision unique edge %s", name) } -// ResetRefundAmount resets all changes to the "refund_amount" field. -func (m *PaymentOrderMutation) ResetRefundAmount() { - m.refund_amount = nil - m.addrefund_amount = nil +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdentityAdoptionDecisionMutation) ResetEdge(name string) error { + switch name { + case identityadoptiondecision.EdgePendingAuthSession: + m.ResetPendingAuthSession() + return nil + case identityadoptiondecision.EdgeIdentity: + m.ResetIdentity() + return nil + } + return fmt.Errorf("unknown IdentityAdoptionDecision edge %s", name) } -// SetRefundReason sets the "refund_reason" field. -func (m *PaymentOrderMutation) SetRefundReason(s string) { - m.refund_reason = &s +// PaymentAuditLogMutation represents an operation that mutates the PaymentAuditLog nodes in the graph. +type PaymentAuditLogMutation struct { + config + op Op + typ string + id *int64 + order_id *string + action *string + detail *string + operator *string + created_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentAuditLog, error) + predicates []predicate.PaymentAuditLog } -// RefundReason returns the value of the "refund_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { - v := m.refund_reason - if v == nil { - return - } - return *v, true -} +var _ ent.Mutation = (*PaymentAuditLogMutation)(nil) -// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") +// paymentauditlogOption allows management of the mutation configuration using functional options. +type paymentauditlogOption func(*PaymentAuditLogMutation) + +// newPaymentAuditLogMutation creates new mutation for the PaymentAuditLog entity. +func newPaymentAuditLogMutation(c config, op Op, opts ...paymentauditlogOption) *PaymentAuditLogMutation { + m := &PaymentAuditLogMutation{ + config: c, + op: op, + typ: TypePaymentAuditLog, + clearedFields: make(map[string]struct{}), } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundReason requires an ID field in the mutation") + for _, opt := range opts { + opt(m) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + return m +} + +// withPaymentAuditLogID sets the ID field of the mutation. +func withPaymentAuditLogID(id int64) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + var ( + err error + once sync.Once + value *PaymentAuditLog + ) + m.oldValue = func(ctx context.Context) (*PaymentAuditLog, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentAuditLog.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - return oldValue.RefundReason, nil } -// ClearRefundReason clears the value of the "refund_reason" field. -func (m *PaymentOrderMutation) ClearRefundReason() { - m.refund_reason = nil - m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +// withPaymentAuditLog sets the old PaymentAuditLog of the mutation. +func withPaymentAuditLog(node *PaymentAuditLog) paymentauditlogOption { + return func(m *PaymentAuditLogMutation) { + m.oldValue = func(context.Context) (*PaymentAuditLog, error) { + return node, nil + } + m.id = &node.ID + } } -// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundReason] - return ok +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentAuditLogMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// ResetRefundReason resets all changes to the "refund_reason" field. -func (m *PaymentOrderMutation) ResetRefundReason() { - m.refund_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundReason) +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentAuditLogMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// SetRefundAt sets the "refund_at" field. -func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { - m.refund_at = &t +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentAuditLogMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// RefundAt returns the value of the "refund_at" field in the mutation. -func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { - v := m.refund_at +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentAuditLogMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentAuditLog.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetOrderID sets the "order_id" field. +func (m *PaymentAuditLogMutation) SetOrderID(s string) { + m.order_id = &s +} + +// OrderID returns the value of the "order_id" field in the mutation. +func (m *PaymentAuditLogMutation) OrderID() (r string, exists bool) { + v := m.order_id if v == nil { return } return *v, true } -// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldOrderID returns the old "order_id" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { +func (m *PaymentAuditLogMutation) OldOrderID(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + return v, errors.New("OldOrderID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundAt requires an ID field in the mutation") + return v, errors.New("OldOrderID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + return v, fmt.Errorf("querying old value for OldOrderID: %w", err) } - return oldValue.RefundAt, nil -} - -// ClearRefundAt clears the value of the "refund_at" field. -func (m *PaymentOrderMutation) ClearRefundAt() { - m.refund_at = nil - m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} -} - -// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundAt] - return ok + return oldValue.OrderID, nil } -// ResetRefundAt resets all changes to the "refund_at" field. -func (m *PaymentOrderMutation) ResetRefundAt() { - m.refund_at = nil - delete(m.clearedFields, paymentorder.FieldRefundAt) +// ResetOrderID resets all changes to the "order_id" field. +func (m *PaymentAuditLogMutation) ResetOrderID() { + m.order_id = nil } -// SetForceRefund sets the "force_refund" field. -func (m *PaymentOrderMutation) SetForceRefund(b bool) { - m.force_refund = &b +// SetAction sets the "action" field. +func (m *PaymentAuditLogMutation) SetAction(s string) { + m.action = &s } -// ForceRefund returns the value of the "force_refund" field in the mutation. -func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { - v := m.force_refund +// Action returns the value of the "action" field in the mutation. +func (m *PaymentAuditLogMutation) Action() (r string, exists bool) { + v := m.action if v == nil { return } return *v, true } -// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldAction returns the old "action" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { +func (m *PaymentAuditLogMutation) OldAction(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + return v, errors.New("OldAction is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldForceRefund requires an ID field in the mutation") + return v, errors.New("OldAction requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + return v, fmt.Errorf("querying old value for OldAction: %w", err) } - return oldValue.ForceRefund, nil + return oldValue.Action, nil } -// ResetForceRefund resets all changes to the "force_refund" field. -func (m *PaymentOrderMutation) ResetForceRefund() { - m.force_refund = nil +// ResetAction resets all changes to the "action" field. +func (m *PaymentAuditLogMutation) ResetAction() { + m.action = nil } -// SetRefundRequestedAt sets the "refund_requested_at" field. -func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { - m.refund_requested_at = &t +// SetDetail sets the "detail" field. +func (m *PaymentAuditLogMutation) SetDetail(s string) { + m.detail = &s } -// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { - v := m.refund_requested_at +// Detail returns the value of the "detail" field in the mutation. +func (m *PaymentAuditLogMutation) Detail() (r string, exists bool) { + v := m.detail if v == nil { return } return *v, true } -// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldDetail returns the old "detail" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { +func (m *PaymentAuditLogMutation) OldDetail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + return v, errors.New("OldDetail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + return v, errors.New("OldDetail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + return v, fmt.Errorf("querying old value for OldDetail: %w", err) } - return oldValue.RefundRequestedAt, nil -} - -// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. -func (m *PaymentOrderMutation) ClearRefundRequestedAt() { - m.refund_requested_at = nil - m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} -} - -// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] - return ok + return oldValue.Detail, nil } -// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. -func (m *PaymentOrderMutation) ResetRefundRequestedAt() { - m.refund_requested_at = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +// ResetDetail resets all changes to the "detail" field. +func (m *PaymentAuditLogMutation) ResetDetail() { + m.detail = nil } -// SetRefundRequestReason sets the "refund_request_reason" field. -func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { - m.refund_request_reason = &s +// SetOperator sets the "operator" field. +func (m *PaymentAuditLogMutation) SetOperator(s string) { + m.operator = &s } -// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { - v := m.refund_request_reason +// Operator returns the value of the "operator" field in the mutation. +func (m *PaymentAuditLogMutation) Operator() (r string, exists bool) { + v := m.operator if v == nil { return } return *v, true } -// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldOperator returns the old "operator" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { +func (m *PaymentAuditLogMutation) OldOperator(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + return v, errors.New("OldOperator is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + return v, errors.New("OldOperator requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + return v, fmt.Errorf("querying old value for OldOperator: %w", err) } - return oldValue.RefundRequestReason, nil + return oldValue.Operator, nil } -// ClearRefundRequestReason clears the value of the "refund_request_reason" field. -func (m *PaymentOrderMutation) ClearRefundRequestReason() { - m.refund_request_reason = nil - m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} -} - -// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] - return ok -} - -// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. -func (m *PaymentOrderMutation) ResetRefundRequestReason() { - m.refund_request_reason = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +// ResetOperator resets all changes to the "operator" field. +func (m *PaymentAuditLogMutation) ResetOperator() { + m.operator = nil } -// SetRefundRequestedBy sets the "refund_requested_by" field. -func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { - m.refund_requested_by = &s +// SetCreatedAt sets the "created_at" field. +func (m *PaymentAuditLogMutation) SetCreatedAt(t time.Time) { + m.created_at = &t } -// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. -func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { - v := m.refund_requested_by +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentAuditLogMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at if v == nil { return } return *v, true } -// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// OldCreatedAt returns the old "created_at" field's value of the PaymentAuditLog entity. +// If the PaymentAuditLog object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { +func (m *PaymentAuditLogMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - return oldValue.RefundRequestedBy, nil + return oldValue.CreatedAt, nil } -// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. -func (m *PaymentOrderMutation) ClearRefundRequestedBy() { - m.refund_requested_by = nil - m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentAuditLogMutation) ResetCreatedAt() { + m.created_at = nil } -// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. -func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] - return ok +// Where appends a list predicates to the PaymentAuditLogMutation builder. +func (m *PaymentAuditLogMutation) Where(ps ...predicate.PaymentAuditLog) { + m.predicates = append(m.predicates, ps...) } -// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. -func (m *PaymentOrderMutation) ResetRefundRequestedBy() { - m.refund_requested_by = nil - delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +// WhereP appends storage-level predicates to the PaymentAuditLogMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentAuditLogMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentAuditLog, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) } -// SetExpiresAt sets the "expires_at" field. -func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { - m.expires_at = &t +// Op returns the operation name. +func (m *PaymentAuditLogMutation) Op() Op { + return m.op } -// ExpiresAt returns the value of the "expires_at" field in the mutation. -func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { - v := m.expires_at - if v == nil { - return - } - return *v, true +// SetOp allows setting the mutation operation. +func (m *PaymentAuditLogMutation) SetOp(op Op) { + m.op = op } -// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") +// Type returns the node type of this mutation (PaymentAuditLog). +func (m *PaymentAuditLogMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentAuditLogMutation) Fields() []string { + fields := make([]string, 0, 5) + if m.order_id != nil { + fields = append(fields, paymentauditlog.FieldOrderID) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldExpiresAt requires an ID field in the mutation") + if m.action != nil { + fields = append(fields, paymentauditlog.FieldAction) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + if m.detail != nil { + fields = append(fields, paymentauditlog.FieldDetail) } - return oldValue.ExpiresAt, nil -} - -// ResetExpiresAt resets all changes to the "expires_at" field. -func (m *PaymentOrderMutation) ResetExpiresAt() { - m.expires_at = nil + if m.operator != nil { + fields = append(fields, paymentauditlog.FieldOperator) + } + if m.created_at != nil { + fields = append(fields, paymentauditlog.FieldCreatedAt) + } + return fields } -// SetPaidAt sets the "paid_at" field. -func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { - m.paid_at = &t +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentAuditLogMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OrderID() + case paymentauditlog.FieldAction: + return m.Action() + case paymentauditlog.FieldDetail: + return m.Detail() + case paymentauditlog.FieldOperator: + return m.Operator() + case paymentauditlog.FieldCreatedAt: + return m.CreatedAt() + } + return nil, false } -// PaidAt returns the value of the "paid_at" field in the mutation. -func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { - v := m.paid_at - if v == nil { - return +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentAuditLogMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentauditlog.FieldOrderID: + return m.OldOrderID(ctx) + case paymentauditlog.FieldAction: + return m.OldAction(ctx) + case paymentauditlog.FieldDetail: + return m.OldDetail(ctx) + case paymentauditlog.FieldOperator: + return m.OldOperator(ctx) + case paymentauditlog.FieldCreatedAt: + return m.OldCreatedAt(ctx) } - return *v, true + return nil, fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaidAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentauditlog.FieldOrderID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderID(v) + return nil + case paymentauditlog.FieldAction: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAction(v) + return nil + case paymentauditlog.FieldDetail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDetail(v) + return nil + case paymentauditlog.FieldOperator: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOperator(v) + return nil + case paymentauditlog.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil } - return oldValue.PaidAt, nil + return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// ClearPaidAt clears the value of the "paid_at" field. -func (m *PaymentOrderMutation) ClearPaidAt() { - m.paid_at = nil - m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentAuditLogMutation) AddedFields() []string { + return nil } -// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) PaidAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldPaidAt] - return ok +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentAuditLogMutation) AddedField(name string) (ent.Value, bool) { + return nil, false } -// ResetPaidAt resets all changes to the "paid_at" field. -func (m *PaymentOrderMutation) ResetPaidAt() { - m.paid_at = nil - delete(m.clearedFields, paymentorder.FieldPaidAt) +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentAuditLogMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown PaymentAuditLog numeric field %s", name) } -// SetCompletedAt sets the "completed_at" field. -func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { - m.completed_at = &t +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentAuditLogMutation) ClearedFields() []string { + return nil } -// CompletedAt returns the value of the "completed_at" field in the mutation. -func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { - v := m.completed_at - if v == nil { - return - } - return *v, true -} - -// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCompletedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) - } - return oldValue.CompletedAt, nil -} - -// ClearCompletedAt clears the value of the "completed_at" field. -func (m *PaymentOrderMutation) ClearCompletedAt() { - m.completed_at = nil - m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} -} - -// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) CompletedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldCompletedAt] +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentAuditLogMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] return ok } -// ResetCompletedAt resets all changes to the "completed_at" field. -func (m *PaymentOrderMutation) ResetCompletedAt() { - m.completed_at = nil - delete(m.clearedFields, paymentorder.FieldCompletedAt) -} - -// SetFailedAt sets the "failed_at" field. -func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { - m.failed_at = &t +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentAuditLogMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentAuditLog nullable field %s", name) } -// FailedAt returns the value of the "failed_at" field in the mutation. -func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { - v := m.failed_at - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentAuditLogMutation) ResetField(name string) error { + switch name { + case paymentauditlog.FieldOrderID: + m.ResetOrderID() + return nil + case paymentauditlog.FieldAction: + m.ResetAction() + return nil + case paymentauditlog.FieldDetail: + m.ResetDetail() + return nil + case paymentauditlog.FieldOperator: + m.ResetOperator() + return nil + case paymentauditlog.FieldCreatedAt: + m.ResetCreatedAt() + return nil } - return *v, true + return fmt.Errorf("unknown PaymentAuditLog field %s", name) } -// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedAt requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) - } - return oldValue.FailedAt, nil +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentAuditLogMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ClearFailedAt clears the value of the "failed_at" field. -func (m *PaymentOrderMutation) ClearFailedAt() { - m.failed_at = nil - m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PaymentAuditLogMutation) AddedIDs(name string) []ent.Value { + return nil } -// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedAtCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedAt] - return ok +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentAuditLogMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// ResetFailedAt resets all changes to the "failed_at" field. -func (m *PaymentOrderMutation) ResetFailedAt() { - m.failed_at = nil - delete(m.clearedFields, paymentorder.FieldFailedAt) +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentAuditLogMutation) RemovedIDs(name string) []ent.Value { + return nil } -// SetFailedReason sets the "failed_reason" field. -func (m *PaymentOrderMutation) SetFailedReason(s string) { - m.failed_reason = &s +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentAuditLogMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges } -// FailedReason returns the value of the "failed_reason" field in the mutation. -func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { - v := m.failed_reason - if v == nil { - return - } - return *v, true +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentAuditLogMutation) EdgeCleared(name string) bool { + return false } -// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldFailedReason requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) - } - return oldValue.FailedReason, nil +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentAuditLogMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog unique edge %s", name) } -// ClearFailedReason clears the value of the "failed_reason" field. -func (m *PaymentOrderMutation) ClearFailedReason() { - m.failed_reason = nil - m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentAuditLogMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentAuditLog edge %s", name) } -// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. -func (m *PaymentOrderMutation) FailedReasonCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldFailedReason] - return ok +// PaymentOrderMutation represents an operation that mutates the PaymentOrder nodes in the graph. +type PaymentOrderMutation struct { + config + op Op + typ string + id *int64 + user_email *string + user_name *string + user_notes *string + amount *float64 + addamount *float64 + pay_amount *float64 + addpay_amount *float64 + fee_rate *float64 + addfee_rate *float64 + recharge_code *string + out_trade_no *string + payment_type *string + payment_trade_no *string + pay_url *string + qr_code *string + qr_code_img *string + order_type *string + plan_id *int64 + addplan_id *int64 + subscription_group_id *int64 + addsubscription_group_id *int64 + subscription_days *int + addsubscription_days *int + provider_instance_id *string + status *string + refund_amount *float64 + addrefund_amount *float64 + refund_reason *string + refund_at *time.Time + force_refund *bool + refund_requested_at *time.Time + refund_request_reason *string + refund_requested_by *string + expires_at *time.Time + paid_at *time.Time + completed_at *time.Time + failed_at *time.Time + failed_reason *string + client_ip *string + src_host *string + src_url *string + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + user *int64 + cleareduser bool + done bool + oldValue func(context.Context) (*PaymentOrder, error) + predicates []predicate.PaymentOrder } -// ResetFailedReason resets all changes to the "failed_reason" field. -func (m *PaymentOrderMutation) ResetFailedReason() { - m.failed_reason = nil - delete(m.clearedFields, paymentorder.FieldFailedReason) -} +var _ ent.Mutation = (*PaymentOrderMutation)(nil) -// SetClientIP sets the "client_ip" field. -func (m *PaymentOrderMutation) SetClientIP(s string) { - m.client_ip = &s -} +// paymentorderOption allows management of the mutation configuration using functional options. +type paymentorderOption func(*PaymentOrderMutation) -// ClientIP returns the value of the "client_ip" field in the mutation. -func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { - v := m.client_ip - if v == nil { - return +// newPaymentOrderMutation creates new mutation for the PaymentOrder entity. +func newPaymentOrderMutation(c config, op Op, opts ...paymentorderOption) *PaymentOrderMutation { + m := &PaymentOrderMutation{ + config: c, + op: op, + typ: TypePaymentOrder, + clearedFields: make(map[string]struct{}), } - return *v, true + for _, opt := range opts { + opt(m) + } + return m } -// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. -// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldClientIP is only allowed on UpdateOne operations") +// withPaymentOrderID sets the ID field of the mutation. +func withPaymentOrderID(id int64) paymentorderOption { + return func(m *PaymentOrderMutation) { + var ( + err error + once sync.Once + value *PaymentOrder + ) + m.oldValue = func(ctx context.Context) (*PaymentOrder, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentOrder.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldClientIP requires an ID field in the mutation") +} + +// withPaymentOrder sets the old PaymentOrder of the mutation. +func withPaymentOrder(node *PaymentOrder) paymentorderOption { + return func(m *PaymentOrderMutation) { + m.oldValue = func(context.Context) (*PaymentOrder, error) { + return node, nil + } + m.id = &node.ID } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldClientIP: %w", err) +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentOrderMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentOrderMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") } - return oldValue.ClientIP, nil + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// ResetClientIP resets all changes to the "client_ip" field. -func (m *PaymentOrderMutation) ResetClientIP() { - m.client_ip = nil +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentOrderMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true } -// SetSrcHost sets the "src_host" field. -func (m *PaymentOrderMutation) SetSrcHost(s string) { - m.src_host = &s +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentOrderMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentOrder.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } } -// SrcHost returns the value of the "src_host" field in the mutation. -func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { - v := m.src_host +// SetUserID sets the "user_id" field. +func (m *PaymentOrderMutation) SetUserID(i int64) { + m.user = &i +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *PaymentOrderMutation) UserID() (r int64, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. +// OldUserID returns the old "user_id" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { +func (m *PaymentOrderMutation) OldUserID(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcHost requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.SrcHost, nil + return oldValue.UserID, nil } -// ResetSrcHost resets all changes to the "src_host" field. -func (m *PaymentOrderMutation) ResetSrcHost() { - m.src_host = nil +// ResetUserID resets all changes to the "user_id" field. +func (m *PaymentOrderMutation) ResetUserID() { + m.user = nil } -// SetSrcURL sets the "src_url" field. -func (m *PaymentOrderMutation) SetSrcURL(s string) { - m.src_url = &s +// SetUserEmail sets the "user_email" field. +func (m *PaymentOrderMutation) SetUserEmail(s string) { + m.user_email = &s } -// SrcURL returns the value of the "src_url" field in the mutation. -func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { - v := m.src_url +// UserEmail returns the value of the "user_email" field in the mutation. +func (m *PaymentOrderMutation) UserEmail() (r string, exists bool) { + v := m.user_email if v == nil { return } return *v, true } -// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. +// OldUserEmail returns the old "user_email" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { +func (m *PaymentOrderMutation) OldUserEmail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + return v, errors.New("OldUserEmail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSrcURL requires an ID field in the mutation") + return v, errors.New("OldUserEmail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + return v, fmt.Errorf("querying old value for OldUserEmail: %w", err) } - return oldValue.SrcURL, nil -} - -// ClearSrcURL clears the value of the "src_url" field. -func (m *PaymentOrderMutation) ClearSrcURL() { - m.src_url = nil - m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} -} - -// SrcURLCleared returns if the "src_url" field was cleared in this mutation. -func (m *PaymentOrderMutation) SrcURLCleared() bool { - _, ok := m.clearedFields[paymentorder.FieldSrcURL] - return ok + return oldValue.UserEmail, nil } -// ResetSrcURL resets all changes to the "src_url" field. -func (m *PaymentOrderMutation) ResetSrcURL() { - m.src_url = nil - delete(m.clearedFields, paymentorder.FieldSrcURL) +// ResetUserEmail resets all changes to the "user_email" field. +func (m *PaymentOrderMutation) ResetUserEmail() { + m.user_email = nil } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// SetUserName sets the "user_name" field. +func (m *PaymentOrderMutation) SetUserName(s string) { + m.user_name = &s } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// UserName returns the value of the "user_name" field in the mutation. +func (m *PaymentOrderMutation) UserName() (r string, exists bool) { + v := m.user_name if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. +// OldUserName returns the old "user_name" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentOrderMutation) OldUserName(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldUserName is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldUserName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldUserName: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.UserName, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentOrderMutation) ResetCreatedAt() { - m.created_at = nil +// ResetUserName resets all changes to the "user_name" field. +func (m *PaymentOrderMutation) ResetUserName() { + m.user_name = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetUserNotes sets the "user_notes" field. +func (m *PaymentOrderMutation) SetUserNotes(s string) { + m.user_notes = &s } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// UserNotes returns the value of the "user_notes" field in the mutation. +func (m *PaymentOrderMutation) UserNotes() (r string, exists bool) { + v := m.user_notes if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. +// OldUserNotes returns the old "user_notes" field's value of the PaymentOrder entity. // If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PaymentOrderMutation) OldUserNotes(ctx context.Context) (v *string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldUserNotes is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldUserNotes requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldUserNotes: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.UserNotes, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentOrderMutation) ResetUpdatedAt() { - m.updated_at = nil +// ClearUserNotes clears the value of the "user_notes" field. +func (m *PaymentOrderMutation) ClearUserNotes() { + m.user_notes = nil + m.clearedFields[paymentorder.FieldUserNotes] = struct{}{} } -// ClearUser clears the "user" edge to the User entity. -func (m *PaymentOrderMutation) ClearUser() { - m.cleareduser = true - m.clearedFields[paymentorder.FieldUserID] = struct{}{} +// UserNotesCleared returns if the "user_notes" field was cleared in this mutation. +func (m *PaymentOrderMutation) UserNotesCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldUserNotes] + return ok } -// UserCleared reports if the "user" edge to the User entity was cleared. -func (m *PaymentOrderMutation) UserCleared() bool { - return m.cleareduser +// ResetUserNotes resets all changes to the "user_notes" field. +func (m *PaymentOrderMutation) ResetUserNotes() { + m.user_notes = nil + delete(m.clearedFields, paymentorder.FieldUserNotes) } -// UserIDs returns the "user" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// UserID instead. It exists only for internal usage by the builders. -func (m *PaymentOrderMutation) UserIDs() (ids []int64) { - if id := m.user; id != nil { - ids = append(ids, *id) +// SetAmount sets the "amount" field. +func (m *PaymentOrderMutation) SetAmount(f float64) { + m.amount = &f + m.addamount = nil +} + +// Amount returns the value of the "amount" field in the mutation. +func (m *PaymentOrderMutation) Amount() (r float64, exists bool) { + v := m.amount + if v == nil { + return } - return + return *v, true } -// ResetUser resets all changes to the "user" edge. -func (m *PaymentOrderMutation) ResetUser() { - m.user = nil - m.cleareduser = false +// OldAmount returns the old "amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAmount: %w", err) + } + return oldValue.Amount, nil } -// Where appends a list predicates to the PaymentOrderMutation builder. -func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { - m.predicates = append(m.predicates, ps...) +// AddAmount adds f to the "amount" field. +func (m *PaymentOrderMutation) AddAmount(f float64) { + if m.addamount != nil { + *m.addamount += f + } else { + m.addamount = &f + } } -// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentOrder, len(ps)) - for i := range ps { - p[i] = ps[i] +// AddedAmount returns the value that was added to the "amount" field in this mutation. +func (m *PaymentOrderMutation) AddedAmount() (r float64, exists bool) { + v := m.addamount + if v == nil { + return } - m.Where(p...) + return *v, true } -// Op returns the operation name. -func (m *PaymentOrderMutation) Op() Op { - return m.op +// ResetAmount resets all changes to the "amount" field. +func (m *PaymentOrderMutation) ResetAmount() { + m.amount = nil + m.addamount = nil } -// SetOp allows setting the mutation operation. -func (m *PaymentOrderMutation) SetOp(op Op) { - m.op = op +// SetPayAmount sets the "pay_amount" field. +func (m *PaymentOrderMutation) SetPayAmount(f float64) { + m.pay_amount = &f + m.addpay_amount = nil } -// Type returns the node type of this mutation (PaymentOrder). -func (m *PaymentOrderMutation) Type() string { - return m.typ +// PayAmount returns the value of the "pay_amount" field in the mutation. +func (m *PaymentOrderMutation) PayAmount() (r float64, exists bool) { + v := m.pay_amount + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *PaymentOrderMutation) Fields() []string { - fields := make([]string, 0, 37) - if m.user != nil { - fields = append(fields, paymentorder.FieldUserID) - } - if m.user_email != nil { - fields = append(fields, paymentorder.FieldUserEmail) - } - if m.user_name != nil { - fields = append(fields, paymentorder.FieldUserName) +// OldPayAmount returns the old "pay_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayAmount is only allowed on UpdateOne operations") } - if m.user_notes != nil { - fields = append(fields, paymentorder.FieldUserNotes) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayAmount requires an ID field in the mutation") } - if m.amount != nil { - fields = append(fields, paymentorder.FieldAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayAmount: %w", err) } - if m.pay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) + return oldValue.PayAmount, nil +} + +// AddPayAmount adds f to the "pay_amount" field. +func (m *PaymentOrderMutation) AddPayAmount(f float64) { + if m.addpay_amount != nil { + *m.addpay_amount += f + } else { + m.addpay_amount = &f } - if m.fee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) +} + +// AddedPayAmount returns the value that was added to the "pay_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedPayAmount() (r float64, exists bool) { + v := m.addpay_amount + if v == nil { + return } - if m.recharge_code != nil { - fields = append(fields, paymentorder.FieldRechargeCode) + return *v, true +} + +// ResetPayAmount resets all changes to the "pay_amount" field. +func (m *PaymentOrderMutation) ResetPayAmount() { + m.pay_amount = nil + m.addpay_amount = nil +} + +// SetFeeRate sets the "fee_rate" field. +func (m *PaymentOrderMutation) SetFeeRate(f float64) { + m.fee_rate = &f + m.addfee_rate = nil +} + +// FeeRate returns the value of the "fee_rate" field in the mutation. +func (m *PaymentOrderMutation) FeeRate() (r float64, exists bool) { + v := m.fee_rate + if v == nil { + return } - if m.out_trade_no != nil { - fields = append(fields, paymentorder.FieldOutTradeNo) + return *v, true +} + +// OldFeeRate returns the old "fee_rate" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFeeRate(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFeeRate is only allowed on UpdateOne operations") } - if m.payment_type != nil { - fields = append(fields, paymentorder.FieldPaymentType) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFeeRate requires an ID field in the mutation") } - if m.payment_trade_no != nil { - fields = append(fields, paymentorder.FieldPaymentTradeNo) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFeeRate: %w", err) } - if m.pay_url != nil { - fields = append(fields, paymentorder.FieldPayURL) + return oldValue.FeeRate, nil +} + +// AddFeeRate adds f to the "fee_rate" field. +func (m *PaymentOrderMutation) AddFeeRate(f float64) { + if m.addfee_rate != nil { + *m.addfee_rate += f + } else { + m.addfee_rate = &f } - if m.qr_code != nil { - fields = append(fields, paymentorder.FieldQrCode) +} + +// AddedFeeRate returns the value that was added to the "fee_rate" field in this mutation. +func (m *PaymentOrderMutation) AddedFeeRate() (r float64, exists bool) { + v := m.addfee_rate + if v == nil { + return } - if m.qr_code_img != nil { - fields = append(fields, paymentorder.FieldQrCodeImg) + return *v, true +} + +// ResetFeeRate resets all changes to the "fee_rate" field. +func (m *PaymentOrderMutation) ResetFeeRate() { + m.fee_rate = nil + m.addfee_rate = nil +} + +// SetRechargeCode sets the "recharge_code" field. +func (m *PaymentOrderMutation) SetRechargeCode(s string) { + m.recharge_code = &s +} + +// RechargeCode returns the value of the "recharge_code" field in the mutation. +func (m *PaymentOrderMutation) RechargeCode() (r string, exists bool) { + v := m.recharge_code + if v == nil { + return } - if m.order_type != nil { - fields = append(fields, paymentorder.FieldOrderType) + return *v, true +} + +// OldRechargeCode returns the old "recharge_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRechargeCode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRechargeCode is only allowed on UpdateOne operations") } - if m.plan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRechargeCode requires an ID field in the mutation") } - if m.subscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRechargeCode: %w", err) } - if m.subscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) + return oldValue.RechargeCode, nil +} + +// ResetRechargeCode resets all changes to the "recharge_code" field. +func (m *PaymentOrderMutation) ResetRechargeCode() { + m.recharge_code = nil +} + +// SetOutTradeNo sets the "out_trade_no" field. +func (m *PaymentOrderMutation) SetOutTradeNo(s string) { + m.out_trade_no = &s +} + +// OutTradeNo returns the value of the "out_trade_no" field in the mutation. +func (m *PaymentOrderMutation) OutTradeNo() (r string, exists bool) { + v := m.out_trade_no + if v == nil { + return } - if m.provider_instance_id != nil { - fields = append(fields, paymentorder.FieldProviderInstanceID) + return *v, true +} + +// OldOutTradeNo returns the old "out_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldOutTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOutTradeNo is only allowed on UpdateOne operations") } - if m.status != nil { - fields = append(fields, paymentorder.FieldStatus) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOutTradeNo requires an ID field in the mutation") } - if m.refund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOutTradeNo: %w", err) } - if m.refund_reason != nil { - fields = append(fields, paymentorder.FieldRefundReason) - } - if m.refund_at != nil { - fields = append(fields, paymentorder.FieldRefundAt) - } - if m.force_refund != nil { - fields = append(fields, paymentorder.FieldForceRefund) - } - if m.refund_requested_at != nil { - fields = append(fields, paymentorder.FieldRefundRequestedAt) + return oldValue.OutTradeNo, nil +} + +// ResetOutTradeNo resets all changes to the "out_trade_no" field. +func (m *PaymentOrderMutation) ResetOutTradeNo() { + m.out_trade_no = nil +} + +// SetPaymentType sets the "payment_type" field. +func (m *PaymentOrderMutation) SetPaymentType(s string) { + m.payment_type = &s +} + +// PaymentType returns the value of the "payment_type" field in the mutation. +func (m *PaymentOrderMutation) PaymentType() (r string, exists bool) { + v := m.payment_type + if v == nil { + return } - if m.refund_request_reason != nil { - fields = append(fields, paymentorder.FieldRefundRequestReason) + return *v, true +} + +// OldPaymentType returns the old "payment_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentType is only allowed on UpdateOne operations") } - if m.refund_requested_by != nil { - fields = append(fields, paymentorder.FieldRefundRequestedBy) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentType requires an ID field in the mutation") } - if m.expires_at != nil { - fields = append(fields, paymentorder.FieldExpiresAt) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentType: %w", err) } - if m.paid_at != nil { - fields = append(fields, paymentorder.FieldPaidAt) + return oldValue.PaymentType, nil +} + +// ResetPaymentType resets all changes to the "payment_type" field. +func (m *PaymentOrderMutation) ResetPaymentType() { + m.payment_type = nil +} + +// SetPaymentTradeNo sets the "payment_trade_no" field. +func (m *PaymentOrderMutation) SetPaymentTradeNo(s string) { + m.payment_trade_no = &s +} + +// PaymentTradeNo returns the value of the "payment_trade_no" field in the mutation. +func (m *PaymentOrderMutation) PaymentTradeNo() (r string, exists bool) { + v := m.payment_trade_no + if v == nil { + return } - if m.completed_at != nil { - fields = append(fields, paymentorder.FieldCompletedAt) + return *v, true +} + +// OldPaymentTradeNo returns the old "payment_trade_no" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaymentTradeNo(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentTradeNo is only allowed on UpdateOne operations") } - if m.failed_at != nil { - fields = append(fields, paymentorder.FieldFailedAt) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentTradeNo requires an ID field in the mutation") } - if m.failed_reason != nil { - fields = append(fields, paymentorder.FieldFailedReason) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentTradeNo: %w", err) } - if m.client_ip != nil { - fields = append(fields, paymentorder.FieldClientIP) + return oldValue.PaymentTradeNo, nil +} + +// ResetPaymentTradeNo resets all changes to the "payment_trade_no" field. +func (m *PaymentOrderMutation) ResetPaymentTradeNo() { + m.payment_trade_no = nil +} + +// SetPayURL sets the "pay_url" field. +func (m *PaymentOrderMutation) SetPayURL(s string) { + m.pay_url = &s +} + +// PayURL returns the value of the "pay_url" field in the mutation. +func (m *PaymentOrderMutation) PayURL() (r string, exists bool) { + v := m.pay_url + if v == nil { + return } - if m.src_host != nil { - fields = append(fields, paymentorder.FieldSrcHost) + return *v, true +} + +// OldPayURL returns the old "pay_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPayURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPayURL is only allowed on UpdateOne operations") } - if m.src_url != nil { - fields = append(fields, paymentorder.FieldSrcURL) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPayURL requires an ID field in the mutation") } - if m.created_at != nil { - fields = append(fields, paymentorder.FieldCreatedAt) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPayURL: %w", err) } - if m.updated_at != nil { - fields = append(fields, paymentorder.FieldUpdatedAt) + return oldValue.PayURL, nil +} + +// ClearPayURL clears the value of the "pay_url" field. +func (m *PaymentOrderMutation) ClearPayURL() { + m.pay_url = nil + m.clearedFields[paymentorder.FieldPayURL] = struct{}{} +} + +// PayURLCleared returns if the "pay_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) PayURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPayURL] + return ok +} + +// ResetPayURL resets all changes to the "pay_url" field. +func (m *PaymentOrderMutation) ResetPayURL() { + m.pay_url = nil + delete(m.clearedFields, paymentorder.FieldPayURL) +} + +// SetQrCode sets the "qr_code" field. +func (m *PaymentOrderMutation) SetQrCode(s string) { + m.qr_code = &s +} + +// QrCode returns the value of the "qr_code" field in the mutation. +func (m *PaymentOrderMutation) QrCode() (r string, exists bool) { + v := m.qr_code + if v == nil { + return } - return fields + return *v, true } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { - switch name { - case paymentorder.FieldUserID: - return m.UserID() - case paymentorder.FieldUserEmail: - return m.UserEmail() - case paymentorder.FieldUserName: - return m.UserName() - case paymentorder.FieldUserNotes: - return m.UserNotes() - case paymentorder.FieldAmount: - return m.Amount() - case paymentorder.FieldPayAmount: - return m.PayAmount() - case paymentorder.FieldFeeRate: - return m.FeeRate() - case paymentorder.FieldRechargeCode: - return m.RechargeCode() - case paymentorder.FieldOutTradeNo: - return m.OutTradeNo() - case paymentorder.FieldPaymentType: - return m.PaymentType() - case paymentorder.FieldPaymentTradeNo: - return m.PaymentTradeNo() - case paymentorder.FieldPayURL: - return m.PayURL() - case paymentorder.FieldQrCode: - return m.QrCode() - case paymentorder.FieldQrCodeImg: - return m.QrCodeImg() - case paymentorder.FieldOrderType: - return m.OrderType() - case paymentorder.FieldPlanID: - return m.PlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.SubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.SubscriptionDays() - case paymentorder.FieldProviderInstanceID: - return m.ProviderInstanceID() - case paymentorder.FieldStatus: - return m.Status() - case paymentorder.FieldRefundAmount: - return m.RefundAmount() - case paymentorder.FieldRefundReason: - return m.RefundReason() - case paymentorder.FieldRefundAt: - return m.RefundAt() - case paymentorder.FieldForceRefund: - return m.ForceRefund() - case paymentorder.FieldRefundRequestedAt: - return m.RefundRequestedAt() - case paymentorder.FieldRefundRequestReason: - return m.RefundRequestReason() - case paymentorder.FieldRefundRequestedBy: - return m.RefundRequestedBy() - case paymentorder.FieldExpiresAt: - return m.ExpiresAt() - case paymentorder.FieldPaidAt: - return m.PaidAt() - case paymentorder.FieldCompletedAt: - return m.CompletedAt() - case paymentorder.FieldFailedAt: - return m.FailedAt() - case paymentorder.FieldFailedReason: - return m.FailedReason() - case paymentorder.FieldClientIP: - return m.ClientIP() - case paymentorder.FieldSrcHost: - return m.SrcHost() - case paymentorder.FieldSrcURL: - return m.SrcURL() - case paymentorder.FieldCreatedAt: - return m.CreatedAt() - case paymentorder.FieldUpdatedAt: - return m.UpdatedAt() +// OldQrCode returns the old "qr_code" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldQrCode(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCode is only allowed on UpdateOne operations") } - return nil, false + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCode: %w", err) + } + return oldValue.QrCode, nil } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case paymentorder.FieldUserID: - return m.OldUserID(ctx) - case paymentorder.FieldUserEmail: - return m.OldUserEmail(ctx) - case paymentorder.FieldUserName: - return m.OldUserName(ctx) - case paymentorder.FieldUserNotes: - return m.OldUserNotes(ctx) - case paymentorder.FieldAmount: - return m.OldAmount(ctx) - case paymentorder.FieldPayAmount: - return m.OldPayAmount(ctx) - case paymentorder.FieldFeeRate: - return m.OldFeeRate(ctx) - case paymentorder.FieldRechargeCode: - return m.OldRechargeCode(ctx) - case paymentorder.FieldOutTradeNo: - return m.OldOutTradeNo(ctx) - case paymentorder.FieldPaymentType: - return m.OldPaymentType(ctx) - case paymentorder.FieldPaymentTradeNo: - return m.OldPaymentTradeNo(ctx) - case paymentorder.FieldPayURL: - return m.OldPayURL(ctx) - case paymentorder.FieldQrCode: - return m.OldQrCode(ctx) - case paymentorder.FieldQrCodeImg: - return m.OldQrCodeImg(ctx) - case paymentorder.FieldOrderType: - return m.OldOrderType(ctx) - case paymentorder.FieldPlanID: - return m.OldPlanID(ctx) - case paymentorder.FieldSubscriptionGroupID: - return m.OldSubscriptionGroupID(ctx) - case paymentorder.FieldSubscriptionDays: - return m.OldSubscriptionDays(ctx) - case paymentorder.FieldProviderInstanceID: - return m.OldProviderInstanceID(ctx) - case paymentorder.FieldStatus: - return m.OldStatus(ctx) - case paymentorder.FieldRefundAmount: - return m.OldRefundAmount(ctx) - case paymentorder.FieldRefundReason: - return m.OldRefundReason(ctx) - case paymentorder.FieldRefundAt: - return m.OldRefundAt(ctx) - case paymentorder.FieldForceRefund: - return m.OldForceRefund(ctx) - case paymentorder.FieldRefundRequestedAt: - return m.OldRefundRequestedAt(ctx) - case paymentorder.FieldRefundRequestReason: - return m.OldRefundRequestReason(ctx) - case paymentorder.FieldRefundRequestedBy: - return m.OldRefundRequestedBy(ctx) - case paymentorder.FieldExpiresAt: - return m.OldExpiresAt(ctx) - case paymentorder.FieldPaidAt: - return m.OldPaidAt(ctx) - case paymentorder.FieldCompletedAt: - return m.OldCompletedAt(ctx) - case paymentorder.FieldFailedAt: - return m.OldFailedAt(ctx) - case paymentorder.FieldFailedReason: - return m.OldFailedReason(ctx) - case paymentorder.FieldClientIP: - return m.OldClientIP(ctx) - case paymentorder.FieldSrcHost: - return m.OldSrcHost(ctx) - case paymentorder.FieldSrcURL: - return m.OldSrcURL(ctx) - case paymentorder.FieldCreatedAt: - return m.OldCreatedAt(ctx) - case paymentorder.FieldUpdatedAt: - return m.OldUpdatedAt(ctx) +// ClearQrCode clears the value of the "qr_code" field. +func (m *PaymentOrderMutation) ClearQrCode() { + m.qr_code = nil + m.clearedFields[paymentorder.FieldQrCode] = struct{}{} +} + +// QrCodeCleared returns if the "qr_code" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCode] + return ok +} + +// ResetQrCode resets all changes to the "qr_code" field. +func (m *PaymentOrderMutation) ResetQrCode() { + m.qr_code = nil + delete(m.clearedFields, paymentorder.FieldQrCode) +} + +// SetQrCodeImg sets the "qr_code_img" field. +func (m *PaymentOrderMutation) SetQrCodeImg(s string) { + m.qr_code_img = &s +} + +// QrCodeImg returns the value of the "qr_code_img" field in the mutation. +func (m *PaymentOrderMutation) QrCodeImg() (r string, exists bool) { + v := m.qr_code_img + if v == nil { + return } - return nil, fmt.Errorf("unknown PaymentOrder field %s", name) + return *v, true } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { - switch name { - case paymentorder.FieldUserID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserID(v) - return nil - case paymentorder.FieldUserEmail: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserEmail(v) - return nil - case paymentorder.FieldUserName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserName(v) - return nil - case paymentorder.FieldUserNotes: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUserNotes(v) - return nil - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFeeRate(v) - return nil - case paymentorder.FieldRechargeCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRechargeCode(v) - return nil - case paymentorder.FieldOutTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOutTradeNo(v) - return nil - case paymentorder.FieldPaymentType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaymentType(v) - return nil - case paymentorder.FieldPaymentTradeNo: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaymentTradeNo(v) - return nil - case paymentorder.FieldPayURL: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPayURL(v) - return nil - case paymentorder.FieldQrCode: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCode(v) - return nil - case paymentorder.FieldQrCodeImg: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetQrCodeImg(v) - return nil - case paymentorder.FieldOrderType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOrderType(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSubscriptionDays(v) - return nil - case paymentorder.FieldProviderInstanceID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetProviderInstanceID(v) - return nil - case paymentorder.FieldStatus: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStatus(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) +// OldQrCodeImg returns the old "qr_code_img" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldQrCodeImg(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQrCodeImg is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQrCodeImg requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQrCodeImg: %w", err) + } + return oldValue.QrCodeImg, nil +} + +// ClearQrCodeImg clears the value of the "qr_code_img" field. +func (m *PaymentOrderMutation) ClearQrCodeImg() { + m.qr_code_img = nil + m.clearedFields[paymentorder.FieldQrCodeImg] = struct{}{} +} + +// QrCodeImgCleared returns if the "qr_code_img" field was cleared in this mutation. +func (m *PaymentOrderMutation) QrCodeImgCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldQrCodeImg] + return ok +} + +// ResetQrCodeImg resets all changes to the "qr_code_img" field. +func (m *PaymentOrderMutation) ResetQrCodeImg() { + m.qr_code_img = nil + delete(m.clearedFields, paymentorder.FieldQrCodeImg) +} + +// SetOrderType sets the "order_type" field. +func (m *PaymentOrderMutation) SetOrderType(s string) { + m.order_type = &s +} + +// OrderType returns the value of the "order_type" field in the mutation. +func (m *PaymentOrderMutation) OrderType() (r string, exists bool) { + v := m.order_type + if v == nil { + return + } + return *v, true +} + +// OldOrderType returns the old "order_type" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldOrderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldOrderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldOrderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldOrderType: %w", err) + } + return oldValue.OrderType, nil +} + +// ResetOrderType resets all changes to the "order_type" field. +func (m *PaymentOrderMutation) ResetOrderType() { + m.order_type = nil +} + +// SetPlanID sets the "plan_id" field. +func (m *PaymentOrderMutation) SetPlanID(i int64) { + m.plan_id = &i + m.addplan_id = nil +} + +// PlanID returns the value of the "plan_id" field in the mutation. +func (m *PaymentOrderMutation) PlanID() (r int64, exists bool) { + v := m.plan_id + if v == nil { + return + } + return *v, true +} + +// OldPlanID returns the old "plan_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPlanID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlanID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlanID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlanID: %w", err) + } + return oldValue.PlanID, nil +} + +// AddPlanID adds i to the "plan_id" field. +func (m *PaymentOrderMutation) AddPlanID(i int64) { + if m.addplan_id != nil { + *m.addplan_id += i + } else { + m.addplan_id = &i + } +} + +// AddedPlanID returns the value that was added to the "plan_id" field in this mutation. +func (m *PaymentOrderMutation) AddedPlanID() (r int64, exists bool) { + v := m.addplan_id + if v == nil { + return + } + return *v, true +} + +// ClearPlanID clears the value of the "plan_id" field. +func (m *PaymentOrderMutation) ClearPlanID() { + m.plan_id = nil + m.addplan_id = nil + m.clearedFields[paymentorder.FieldPlanID] = struct{}{} +} + +// PlanIDCleared returns if the "plan_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) PlanIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPlanID] + return ok +} + +// ResetPlanID resets all changes to the "plan_id" field. +func (m *PaymentOrderMutation) ResetPlanID() { + m.plan_id = nil + m.addplan_id = nil + delete(m.clearedFields, paymentorder.FieldPlanID) +} + +// SetSubscriptionGroupID sets the "subscription_group_id" field. +func (m *PaymentOrderMutation) SetSubscriptionGroupID(i int64) { + m.subscription_group_id = &i + m.addsubscription_group_id = nil +} + +// SubscriptionGroupID returns the value of the "subscription_group_id" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionGroupID() (r int64, exists bool) { + v := m.subscription_group_id + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionGroupID returns the old "subscription_group_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSubscriptionGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionGroupID: %w", err) + } + return oldValue.SubscriptionGroupID, nil +} + +// AddSubscriptionGroupID adds i to the "subscription_group_id" field. +func (m *PaymentOrderMutation) AddSubscriptionGroupID(i int64) { + if m.addsubscription_group_id != nil { + *m.addsubscription_group_id += i + } else { + m.addsubscription_group_id = &i + } +} + +// AddedSubscriptionGroupID returns the value that was added to the "subscription_group_id" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionGroupID() (r int64, exists bool) { + v := m.addsubscription_group_id + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionGroupID clears the value of the "subscription_group_id" field. +func (m *PaymentOrderMutation) ClearSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + m.clearedFields[paymentorder.FieldSubscriptionGroupID] = struct{}{} +} + +// SubscriptionGroupIDCleared returns if the "subscription_group_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionGroupIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionGroupID] + return ok +} + +// ResetSubscriptionGroupID resets all changes to the "subscription_group_id" field. +func (m *PaymentOrderMutation) ResetSubscriptionGroupID() { + m.subscription_group_id = nil + m.addsubscription_group_id = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionGroupID) +} + +// SetSubscriptionDays sets the "subscription_days" field. +func (m *PaymentOrderMutation) SetSubscriptionDays(i int) { + m.subscription_days = &i + m.addsubscription_days = nil +} + +// SubscriptionDays returns the value of the "subscription_days" field in the mutation. +func (m *PaymentOrderMutation) SubscriptionDays() (r int, exists bool) { + v := m.subscription_days + if v == nil { + return + } + return *v, true +} + +// OldSubscriptionDays returns the old "subscription_days" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSubscriptionDays(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSubscriptionDays is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSubscriptionDays requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSubscriptionDays: %w", err) + } + return oldValue.SubscriptionDays, nil +} + +// AddSubscriptionDays adds i to the "subscription_days" field. +func (m *PaymentOrderMutation) AddSubscriptionDays(i int) { + if m.addsubscription_days != nil { + *m.addsubscription_days += i + } else { + m.addsubscription_days = &i + } +} + +// AddedSubscriptionDays returns the value that was added to the "subscription_days" field in this mutation. +func (m *PaymentOrderMutation) AddedSubscriptionDays() (r int, exists bool) { + v := m.addsubscription_days + if v == nil { + return + } + return *v, true +} + +// ClearSubscriptionDays clears the value of the "subscription_days" field. +func (m *PaymentOrderMutation) ClearSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + m.clearedFields[paymentorder.FieldSubscriptionDays] = struct{}{} +} + +// SubscriptionDaysCleared returns if the "subscription_days" field was cleared in this mutation. +func (m *PaymentOrderMutation) SubscriptionDaysCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSubscriptionDays] + return ok +} + +// ResetSubscriptionDays resets all changes to the "subscription_days" field. +func (m *PaymentOrderMutation) ResetSubscriptionDays() { + m.subscription_days = nil + m.addsubscription_days = nil + delete(m.clearedFields, paymentorder.FieldSubscriptionDays) +} + +// SetProviderInstanceID sets the "provider_instance_id" field. +func (m *PaymentOrderMutation) SetProviderInstanceID(s string) { + m.provider_instance_id = &s +} + +// ProviderInstanceID returns the value of the "provider_instance_id" field in the mutation. +func (m *PaymentOrderMutation) ProviderInstanceID() (r string, exists bool) { + v := m.provider_instance_id + if v == nil { + return + } + return *v, true +} + +// OldProviderInstanceID returns the old "provider_instance_id" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderInstanceID(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderInstanceID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderInstanceID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderInstanceID: %w", err) + } + return oldValue.ProviderInstanceID, nil +} + +// ClearProviderInstanceID clears the value of the "provider_instance_id" field. +func (m *PaymentOrderMutation) ClearProviderInstanceID() { + m.provider_instance_id = nil + m.clearedFields[paymentorder.FieldProviderInstanceID] = struct{}{} +} + +// ProviderInstanceIDCleared returns if the "provider_instance_id" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderInstanceIDCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderInstanceID] + return ok +} + +// ResetProviderInstanceID resets all changes to the "provider_instance_id" field. +func (m *PaymentOrderMutation) ResetProviderInstanceID() { + m.provider_instance_id = nil + delete(m.clearedFields, paymentorder.FieldProviderInstanceID) +} + +// SetStatus sets the "status" field. +func (m *PaymentOrderMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *PaymentOrderMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *PaymentOrderMutation) ResetStatus() { + m.status = nil +} + +// SetRefundAmount sets the "refund_amount" field. +func (m *PaymentOrderMutation) SetRefundAmount(f float64) { + m.refund_amount = &f + m.addrefund_amount = nil +} + +// RefundAmount returns the value of the "refund_amount" field in the mutation. +func (m *PaymentOrderMutation) RefundAmount() (r float64, exists bool) { + v := m.refund_amount + if v == nil { + return + } + return *v, true +} + +// OldRefundAmount returns the old "refund_amount" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAmount(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAmount is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAmount requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAmount: %w", err) + } + return oldValue.RefundAmount, nil +} + +// AddRefundAmount adds f to the "refund_amount" field. +func (m *PaymentOrderMutation) AddRefundAmount(f float64) { + if m.addrefund_amount != nil { + *m.addrefund_amount += f + } else { + m.addrefund_amount = &f + } +} + +// AddedRefundAmount returns the value that was added to the "refund_amount" field in this mutation. +func (m *PaymentOrderMutation) AddedRefundAmount() (r float64, exists bool) { + v := m.addrefund_amount + if v == nil { + return + } + return *v, true +} + +// ResetRefundAmount resets all changes to the "refund_amount" field. +func (m *PaymentOrderMutation) ResetRefundAmount() { + m.refund_amount = nil + m.addrefund_amount = nil +} + +// SetRefundReason sets the "refund_reason" field. +func (m *PaymentOrderMutation) SetRefundReason(s string) { + m.refund_reason = &s +} + +// RefundReason returns the value of the "refund_reason" field in the mutation. +func (m *PaymentOrderMutation) RefundReason() (r string, exists bool) { + v := m.refund_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundReason returns the old "refund_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundReason: %w", err) + } + return oldValue.RefundReason, nil +} + +// ClearRefundReason clears the value of the "refund_reason" field. +func (m *PaymentOrderMutation) ClearRefundReason() { + m.refund_reason = nil + m.clearedFields[paymentorder.FieldRefundReason] = struct{}{} +} + +// RefundReasonCleared returns if the "refund_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundReason] + return ok +} + +// ResetRefundReason resets all changes to the "refund_reason" field. +func (m *PaymentOrderMutation) ResetRefundReason() { + m.refund_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundReason) +} + +// SetRefundAt sets the "refund_at" field. +func (m *PaymentOrderMutation) SetRefundAt(t time.Time) { + m.refund_at = &t +} + +// RefundAt returns the value of the "refund_at" field in the mutation. +func (m *PaymentOrderMutation) RefundAt() (r time.Time, exists bool) { + v := m.refund_at + if v == nil { + return + } + return *v, true +} + +// OldRefundAt returns the old "refund_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundAt: %w", err) + } + return oldValue.RefundAt, nil +} + +// ClearRefundAt clears the value of the "refund_at" field. +func (m *PaymentOrderMutation) ClearRefundAt() { + m.refund_at = nil + m.clearedFields[paymentorder.FieldRefundAt] = struct{}{} +} + +// RefundAtCleared returns if the "refund_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundAt] + return ok +} + +// ResetRefundAt resets all changes to the "refund_at" field. +func (m *PaymentOrderMutation) ResetRefundAt() { + m.refund_at = nil + delete(m.clearedFields, paymentorder.FieldRefundAt) +} + +// SetForceRefund sets the "force_refund" field. +func (m *PaymentOrderMutation) SetForceRefund(b bool) { + m.force_refund = &b +} + +// ForceRefund returns the value of the "force_refund" field in the mutation. +func (m *PaymentOrderMutation) ForceRefund() (r bool, exists bool) { + v := m.force_refund + if v == nil { + return + } + return *v, true +} + +// OldForceRefund returns the old "force_refund" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldForceRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldForceRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldForceRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldForceRefund: %w", err) + } + return oldValue.ForceRefund, nil +} + +// ResetForceRefund resets all changes to the "force_refund" field. +func (m *PaymentOrderMutation) ResetForceRefund() { + m.force_refund = nil +} + +// SetRefundRequestedAt sets the "refund_requested_at" field. +func (m *PaymentOrderMutation) SetRefundRequestedAt(t time.Time) { + m.refund_requested_at = &t +} + +// RefundRequestedAt returns the value of the "refund_requested_at" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestedAt() (r time.Time, exists bool) { + v := m.refund_requested_at + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedAt returns the old "refund_requested_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedAt: %w", err) + } + return oldValue.RefundRequestedAt, nil +} + +// ClearRefundRequestedAt clears the value of the "refund_requested_at" field. +func (m *PaymentOrderMutation) ClearRefundRequestedAt() { + m.refund_requested_at = nil + m.clearedFields[paymentorder.FieldRefundRequestedAt] = struct{}{} +} + +// RefundRequestedAtCleared returns if the "refund_requested_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedAt] + return ok +} + +// ResetRefundRequestedAt resets all changes to the "refund_requested_at" field. +func (m *PaymentOrderMutation) ResetRefundRequestedAt() { + m.refund_requested_at = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedAt) +} + +// SetRefundRequestReason sets the "refund_request_reason" field. +func (m *PaymentOrderMutation) SetRefundRequestReason(s string) { + m.refund_request_reason = &s +} + +// RefundRequestReason returns the value of the "refund_request_reason" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestReason() (r string, exists bool) { + v := m.refund_request_reason + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestReason returns the old "refund_request_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestReason: %w", err) + } + return oldValue.RefundRequestReason, nil +} + +// ClearRefundRequestReason clears the value of the "refund_request_reason" field. +func (m *PaymentOrderMutation) ClearRefundRequestReason() { + m.refund_request_reason = nil + m.clearedFields[paymentorder.FieldRefundRequestReason] = struct{}{} +} + +// RefundRequestReasonCleared returns if the "refund_request_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestReason] + return ok +} + +// ResetRefundRequestReason resets all changes to the "refund_request_reason" field. +func (m *PaymentOrderMutation) ResetRefundRequestReason() { + m.refund_request_reason = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestReason) +} + +// SetRefundRequestedBy sets the "refund_requested_by" field. +func (m *PaymentOrderMutation) SetRefundRequestedBy(s string) { + m.refund_requested_by = &s +} + +// RefundRequestedBy returns the value of the "refund_requested_by" field in the mutation. +func (m *PaymentOrderMutation) RefundRequestedBy() (r string, exists bool) { + v := m.refund_requested_by + if v == nil { + return + } + return *v, true +} + +// OldRefundRequestedBy returns the old "refund_requested_by" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldRefundRequestedBy(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundRequestedBy is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundRequestedBy requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundRequestedBy: %w", err) + } + return oldValue.RefundRequestedBy, nil +} + +// ClearRefundRequestedBy clears the value of the "refund_requested_by" field. +func (m *PaymentOrderMutation) ClearRefundRequestedBy() { + m.refund_requested_by = nil + m.clearedFields[paymentorder.FieldRefundRequestedBy] = struct{}{} +} + +// RefundRequestedByCleared returns if the "refund_requested_by" field was cleared in this mutation. +func (m *PaymentOrderMutation) RefundRequestedByCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldRefundRequestedBy] + return ok +} + +// ResetRefundRequestedBy resets all changes to the "refund_requested_by" field. +func (m *PaymentOrderMutation) ResetRefundRequestedBy() { + m.refund_requested_by = nil + delete(m.clearedFields, paymentorder.FieldRefundRequestedBy) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PaymentOrderMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PaymentOrderMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PaymentOrderMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// SetPaidAt sets the "paid_at" field. +func (m *PaymentOrderMutation) SetPaidAt(t time.Time) { + m.paid_at = &t +} + +// PaidAt returns the value of the "paid_at" field in the mutation. +func (m *PaymentOrderMutation) PaidAt() (r time.Time, exists bool) { + v := m.paid_at + if v == nil { + return + } + return *v, true +} + +// OldPaidAt returns the old "paid_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldPaidAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaidAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaidAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaidAt: %w", err) + } + return oldValue.PaidAt, nil +} + +// ClearPaidAt clears the value of the "paid_at" field. +func (m *PaymentOrderMutation) ClearPaidAt() { + m.paid_at = nil + m.clearedFields[paymentorder.FieldPaidAt] = struct{}{} +} + +// PaidAtCleared returns if the "paid_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) PaidAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldPaidAt] + return ok +} + +// ResetPaidAt resets all changes to the "paid_at" field. +func (m *PaymentOrderMutation) ResetPaidAt() { + m.paid_at = nil + delete(m.clearedFields, paymentorder.FieldPaidAt) +} + +// SetCompletedAt sets the "completed_at" field. +func (m *PaymentOrderMutation) SetCompletedAt(t time.Time) { + m.completed_at = &t +} + +// CompletedAt returns the value of the "completed_at" field in the mutation. +func (m *PaymentOrderMutation) CompletedAt() (r time.Time, exists bool) { + v := m.completed_at + if v == nil { + return + } + return *v, true +} + +// OldCompletedAt returns the old "completed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldCompletedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletedAt: %w", err) + } + return oldValue.CompletedAt, nil +} + +// ClearCompletedAt clears the value of the "completed_at" field. +func (m *PaymentOrderMutation) ClearCompletedAt() { + m.completed_at = nil + m.clearedFields[paymentorder.FieldCompletedAt] = struct{}{} +} + +// CompletedAtCleared returns if the "completed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) CompletedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldCompletedAt] + return ok +} + +// ResetCompletedAt resets all changes to the "completed_at" field. +func (m *PaymentOrderMutation) ResetCompletedAt() { + m.completed_at = nil + delete(m.clearedFields, paymentorder.FieldCompletedAt) +} + +// SetFailedAt sets the "failed_at" field. +func (m *PaymentOrderMutation) SetFailedAt(t time.Time) { + m.failed_at = &t +} + +// FailedAt returns the value of the "failed_at" field in the mutation. +func (m *PaymentOrderMutation) FailedAt() (r time.Time, exists bool) { + v := m.failed_at + if v == nil { + return + } + return *v, true +} + +// OldFailedAt returns the old "failed_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFailedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedAt: %w", err) + } + return oldValue.FailedAt, nil +} + +// ClearFailedAt clears the value of the "failed_at" field. +func (m *PaymentOrderMutation) ClearFailedAt() { + m.failed_at = nil + m.clearedFields[paymentorder.FieldFailedAt] = struct{}{} +} + +// FailedAtCleared returns if the "failed_at" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedAtCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedAt] + return ok +} + +// ResetFailedAt resets all changes to the "failed_at" field. +func (m *PaymentOrderMutation) ResetFailedAt() { + m.failed_at = nil + delete(m.clearedFields, paymentorder.FieldFailedAt) +} + +// SetFailedReason sets the "failed_reason" field. +func (m *PaymentOrderMutation) SetFailedReason(s string) { + m.failed_reason = &s +} + +// FailedReason returns the value of the "failed_reason" field in the mutation. +func (m *PaymentOrderMutation) FailedReason() (r string, exists bool) { + v := m.failed_reason + if v == nil { + return + } + return *v, true +} + +// OldFailedReason returns the old "failed_reason" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldFailedReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFailedReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFailedReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFailedReason: %w", err) + } + return oldValue.FailedReason, nil +} + +// ClearFailedReason clears the value of the "failed_reason" field. +func (m *PaymentOrderMutation) ClearFailedReason() { + m.failed_reason = nil + m.clearedFields[paymentorder.FieldFailedReason] = struct{}{} +} + +// FailedReasonCleared returns if the "failed_reason" field was cleared in this mutation. +func (m *PaymentOrderMutation) FailedReasonCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldFailedReason] + return ok +} + +// ResetFailedReason resets all changes to the "failed_reason" field. +func (m *PaymentOrderMutation) ResetFailedReason() { + m.failed_reason = nil + delete(m.clearedFields, paymentorder.FieldFailedReason) +} + +// SetClientIP sets the "client_ip" field. +func (m *PaymentOrderMutation) SetClientIP(s string) { + m.client_ip = &s +} + +// ClientIP returns the value of the "client_ip" field in the mutation. +func (m *PaymentOrderMutation) ClientIP() (r string, exists bool) { + v := m.client_ip + if v == nil { + return + } + return *v, true +} + +// OldClientIP returns the old "client_ip" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldClientIP(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClientIP is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClientIP requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClientIP: %w", err) + } + return oldValue.ClientIP, nil +} + +// ResetClientIP resets all changes to the "client_ip" field. +func (m *PaymentOrderMutation) ResetClientIP() { + m.client_ip = nil +} + +// SetSrcHost sets the "src_host" field. +func (m *PaymentOrderMutation) SetSrcHost(s string) { + m.src_host = &s +} + +// SrcHost returns the value of the "src_host" field in the mutation. +func (m *PaymentOrderMutation) SrcHost() (r string, exists bool) { + v := m.src_host + if v == nil { + return + } + return *v, true +} + +// OldSrcHost returns the old "src_host" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcHost(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcHost is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcHost requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcHost: %w", err) + } + return oldValue.SrcHost, nil +} + +// ResetSrcHost resets all changes to the "src_host" field. +func (m *PaymentOrderMutation) ResetSrcHost() { + m.src_host = nil +} + +// SetSrcURL sets the "src_url" field. +func (m *PaymentOrderMutation) SetSrcURL(s string) { + m.src_url = &s +} + +// SrcURL returns the value of the "src_url" field in the mutation. +func (m *PaymentOrderMutation) SrcURL() (r string, exists bool) { + v := m.src_url + if v == nil { + return + } + return *v, true +} + +// OldSrcURL returns the old "src_url" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldSrcURL(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSrcURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSrcURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSrcURL: %w", err) + } + return oldValue.SrcURL, nil +} + +// ClearSrcURL clears the value of the "src_url" field. +func (m *PaymentOrderMutation) ClearSrcURL() { + m.src_url = nil + m.clearedFields[paymentorder.FieldSrcURL] = struct{}{} +} + +// SrcURLCleared returns if the "src_url" field was cleared in this mutation. +func (m *PaymentOrderMutation) SrcURLCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldSrcURL] + return ok +} + +// ResetSrcURL resets all changes to the "src_url" field. +func (m *PaymentOrderMutation) ResetSrcURL() { + m.src_url = nil + delete(m.clearedFields, paymentorder.FieldSrcURL) +} + +// SetCreatedAt sets the "created_at" field. +func (m *PaymentOrderMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentOrderMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentOrderMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PaymentOrderMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PaymentOrderMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PaymentOrderMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// ClearUser clears the "user" edge to the User entity. +func (m *PaymentOrderMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[paymentorder.FieldUserID] = struct{}{} +} + +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *PaymentOrderMutation) UserCleared() bool { + return m.cleareduser +} + +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *PaymentOrderMutation) UserIDs() (ids []int64) { + if id := m.user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetUser resets all changes to the "user" edge. +func (m *PaymentOrderMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// Where appends a list predicates to the PaymentOrderMutation builder. +func (m *PaymentOrderMutation) Where(ps ...predicate.PaymentOrder) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentOrderMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentOrderMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentOrder, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentOrderMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PaymentOrderMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentOrder). +func (m *PaymentOrderMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentOrderMutation) Fields() []string { + fields := make([]string, 0, 37) + if m.user != nil { + fields = append(fields, paymentorder.FieldUserID) + } + if m.user_email != nil { + fields = append(fields, paymentorder.FieldUserEmail) + } + if m.user_name != nil { + fields = append(fields, paymentorder.FieldUserName) + } + if m.user_notes != nil { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.amount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.pay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.fee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.recharge_code != nil { + fields = append(fields, paymentorder.FieldRechargeCode) + } + if m.out_trade_no != nil { + fields = append(fields, paymentorder.FieldOutTradeNo) + } + if m.payment_type != nil { + fields = append(fields, paymentorder.FieldPaymentType) + } + if m.payment_trade_no != nil { + fields = append(fields, paymentorder.FieldPaymentTradeNo) + } + if m.pay_url != nil { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.qr_code != nil { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.qr_code_img != nil { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.order_type != nil { + fields = append(fields, paymentorder.FieldOrderType) + } + if m.plan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.subscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.subscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.provider_instance_id != nil { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.status != nil { + fields = append(fields, paymentorder.FieldStatus) + } + if m.refund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + if m.refund_reason != nil { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.refund_at != nil { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.force_refund != nil { + fields = append(fields, paymentorder.FieldForceRefund) + } + if m.refund_requested_at != nil { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.refund_request_reason != nil { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.refund_requested_by != nil { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.expires_at != nil { + fields = append(fields, paymentorder.FieldExpiresAt) + } + if m.paid_at != nil { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.completed_at != nil { + fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.failed_at != nil { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.failed_reason != nil { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.client_ip != nil { + fields = append(fields, paymentorder.FieldClientIP) + } + if m.src_host != nil { + fields = append(fields, paymentorder.FieldSrcHost) + } + if m.src_url != nil { + fields = append(fields, paymentorder.FieldSrcURL) + } + if m.created_at != nil { + fields = append(fields, paymentorder.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, paymentorder.FieldUpdatedAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldUserID: + return m.UserID() + case paymentorder.FieldUserEmail: + return m.UserEmail() + case paymentorder.FieldUserName: + return m.UserName() + case paymentorder.FieldUserNotes: + return m.UserNotes() + case paymentorder.FieldAmount: + return m.Amount() + case paymentorder.FieldPayAmount: + return m.PayAmount() + case paymentorder.FieldFeeRate: + return m.FeeRate() + case paymentorder.FieldRechargeCode: + return m.RechargeCode() + case paymentorder.FieldOutTradeNo: + return m.OutTradeNo() + case paymentorder.FieldPaymentType: + return m.PaymentType() + case paymentorder.FieldPaymentTradeNo: + return m.PaymentTradeNo() + case paymentorder.FieldPayURL: + return m.PayURL() + case paymentorder.FieldQrCode: + return m.QrCode() + case paymentorder.FieldQrCodeImg: + return m.QrCodeImg() + case paymentorder.FieldOrderType: + return m.OrderType() + case paymentorder.FieldPlanID: + return m.PlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.SubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.SubscriptionDays() + case paymentorder.FieldProviderInstanceID: + return m.ProviderInstanceID() + case paymentorder.FieldStatus: + return m.Status() + case paymentorder.FieldRefundAmount: + return m.RefundAmount() + case paymentorder.FieldRefundReason: + return m.RefundReason() + case paymentorder.FieldRefundAt: + return m.RefundAt() + case paymentorder.FieldForceRefund: + return m.ForceRefund() + case paymentorder.FieldRefundRequestedAt: + return m.RefundRequestedAt() + case paymentorder.FieldRefundRequestReason: + return m.RefundRequestReason() + case paymentorder.FieldRefundRequestedBy: + return m.RefundRequestedBy() + case paymentorder.FieldExpiresAt: + return m.ExpiresAt() + case paymentorder.FieldPaidAt: + return m.PaidAt() + case paymentorder.FieldCompletedAt: + return m.CompletedAt() + case paymentorder.FieldFailedAt: + return m.FailedAt() + case paymentorder.FieldFailedReason: + return m.FailedReason() + case paymentorder.FieldClientIP: + return m.ClientIP() + case paymentorder.FieldSrcHost: + return m.SrcHost() + case paymentorder.FieldSrcURL: + return m.SrcURL() + case paymentorder.FieldCreatedAt: + return m.CreatedAt() + case paymentorder.FieldUpdatedAt: + return m.UpdatedAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case paymentorder.FieldUserID: + return m.OldUserID(ctx) + case paymentorder.FieldUserEmail: + return m.OldUserEmail(ctx) + case paymentorder.FieldUserName: + return m.OldUserName(ctx) + case paymentorder.FieldUserNotes: + return m.OldUserNotes(ctx) + case paymentorder.FieldAmount: + return m.OldAmount(ctx) + case paymentorder.FieldPayAmount: + return m.OldPayAmount(ctx) + case paymentorder.FieldFeeRate: + return m.OldFeeRate(ctx) + case paymentorder.FieldRechargeCode: + return m.OldRechargeCode(ctx) + case paymentorder.FieldOutTradeNo: + return m.OldOutTradeNo(ctx) + case paymentorder.FieldPaymentType: + return m.OldPaymentType(ctx) + case paymentorder.FieldPaymentTradeNo: + return m.OldPaymentTradeNo(ctx) + case paymentorder.FieldPayURL: + return m.OldPayURL(ctx) + case paymentorder.FieldQrCode: + return m.OldQrCode(ctx) + case paymentorder.FieldQrCodeImg: + return m.OldQrCodeImg(ctx) + case paymentorder.FieldOrderType: + return m.OldOrderType(ctx) + case paymentorder.FieldPlanID: + return m.OldPlanID(ctx) + case paymentorder.FieldSubscriptionGroupID: + return m.OldSubscriptionGroupID(ctx) + case paymentorder.FieldSubscriptionDays: + return m.OldSubscriptionDays(ctx) + case paymentorder.FieldProviderInstanceID: + return m.OldProviderInstanceID(ctx) + case paymentorder.FieldStatus: + return m.OldStatus(ctx) + case paymentorder.FieldRefundAmount: + return m.OldRefundAmount(ctx) + case paymentorder.FieldRefundReason: + return m.OldRefundReason(ctx) + case paymentorder.FieldRefundAt: + return m.OldRefundAt(ctx) + case paymentorder.FieldForceRefund: + return m.OldForceRefund(ctx) + case paymentorder.FieldRefundRequestedAt: + return m.OldRefundRequestedAt(ctx) + case paymentorder.FieldRefundRequestReason: + return m.OldRefundRequestReason(ctx) + case paymentorder.FieldRefundRequestedBy: + return m.OldRefundRequestedBy(ctx) + case paymentorder.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case paymentorder.FieldPaidAt: + return m.OldPaidAt(ctx) + case paymentorder.FieldCompletedAt: + return m.OldCompletedAt(ctx) + case paymentorder.FieldFailedAt: + return m.OldFailedAt(ctx) + case paymentorder.FieldFailedReason: + return m.OldFailedReason(ctx) + case paymentorder.FieldClientIP: + return m.OldClientIP(ctx) + case paymentorder.FieldSrcHost: + return m.OldSrcHost(ctx) + case paymentorder.FieldSrcURL: + return m.OldSrcURL(ctx) + case paymentorder.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case paymentorder.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + } + return nil, fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case paymentorder.FieldUserEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserEmail(v) + return nil + case paymentorder.FieldUserName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserName(v) + return nil + case paymentorder.FieldUserNotes: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserNotes(v) + return nil + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFeeRate(v) + return nil + case paymentorder.FieldRechargeCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRechargeCode(v) + return nil + case paymentorder.FieldOutTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOutTradeNo(v) + return nil + case paymentorder.FieldPaymentType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentType(v) + return nil + case paymentorder.FieldPaymentTradeNo: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentTradeNo(v) + return nil + case paymentorder.FieldPayURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPayURL(v) + return nil + case paymentorder.FieldQrCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCode(v) + return nil + case paymentorder.FieldQrCodeImg: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQrCodeImg(v) + return nil + case paymentorder.FieldOrderType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOrderType(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSubscriptionDays(v) + return nil + case paymentorder.FieldProviderInstanceID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderInstanceID(v) + return nil + case paymentorder.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundAmount(v) + return nil + case paymentorder.FieldRefundReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundReason(v) + return nil + case paymentorder.FieldRefundAt: + v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundAmount(v) + m.SetRefundAt(v) + return nil + case paymentorder.FieldForceRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetForceRefund(v) + return nil + case paymentorder.FieldRefundRequestedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedAt(v) + return nil + case paymentorder.FieldRefundRequestReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestReason(v) + return nil + case paymentorder.FieldRefundRequestedBy: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundRequestedBy(v) + return nil + case paymentorder.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case paymentorder.FieldPaidAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaidAt(v) + return nil + case paymentorder.FieldCompletedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletedAt(v) + return nil + case paymentorder.FieldFailedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedAt(v) + return nil + case paymentorder.FieldFailedReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFailedReason(v) + return nil + case paymentorder.FieldClientIP: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClientIP(v) + return nil + case paymentorder.FieldSrcHost: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcHost(v) + return nil + case paymentorder.FieldSrcURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSrcURL(v) + return nil + case paymentorder.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case paymentorder.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentOrderMutation) AddedFields() []string { + var fields []string + if m.addamount != nil { + fields = append(fields, paymentorder.FieldAmount) + } + if m.addpay_amount != nil { + fields = append(fields, paymentorder.FieldPayAmount) + } + if m.addfee_rate != nil { + fields = append(fields, paymentorder.FieldFeeRate) + } + if m.addplan_id != nil { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.addsubscription_group_id != nil { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.addsubscription_days != nil { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.addrefund_amount != nil { + fields = append(fields, paymentorder.FieldRefundAmount) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case paymentorder.FieldAmount: + return m.AddedAmount() + case paymentorder.FieldPayAmount: + return m.AddedPayAmount() + case paymentorder.FieldFeeRate: + return m.AddedFeeRate() + case paymentorder.FieldPlanID: + return m.AddedPlanID() + case paymentorder.FieldSubscriptionGroupID: + return m.AddedSubscriptionGroupID() + case paymentorder.FieldSubscriptionDays: + return m.AddedSubscriptionDays() + case paymentorder.FieldRefundAmount: + return m.AddedRefundAmount() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { + switch name { + case paymentorder.FieldAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAmount(v) + return nil + case paymentorder.FieldPayAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPayAmount(v) + return nil + case paymentorder.FieldFeeRate: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFeeRate(v) + return nil + case paymentorder.FieldPlanID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPlanID(v) + return nil + case paymentorder.FieldSubscriptionGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionGroupID(v) + return nil + case paymentorder.FieldSubscriptionDays: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSubscriptionDays(v) + return nil + case paymentorder.FieldRefundAmount: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRefundAmount(v) + return nil + } + return fmt.Errorf("unknown PaymentOrder numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentOrderMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(paymentorder.FieldUserNotes) { + fields = append(fields, paymentorder.FieldUserNotes) + } + if m.FieldCleared(paymentorder.FieldPayURL) { + fields = append(fields, paymentorder.FieldPayURL) + } + if m.FieldCleared(paymentorder.FieldQrCode) { + fields = append(fields, paymentorder.FieldQrCode) + } + if m.FieldCleared(paymentorder.FieldQrCodeImg) { + fields = append(fields, paymentorder.FieldQrCodeImg) + } + if m.FieldCleared(paymentorder.FieldPlanID) { + fields = append(fields, paymentorder.FieldPlanID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { + fields = append(fields, paymentorder.FieldSubscriptionGroupID) + } + if m.FieldCleared(paymentorder.FieldSubscriptionDays) { + fields = append(fields, paymentorder.FieldSubscriptionDays) + } + if m.FieldCleared(paymentorder.FieldProviderInstanceID) { + fields = append(fields, paymentorder.FieldProviderInstanceID) + } + if m.FieldCleared(paymentorder.FieldRefundReason) { + fields = append(fields, paymentorder.FieldRefundReason) + } + if m.FieldCleared(paymentorder.FieldRefundAt) { + fields = append(fields, paymentorder.FieldRefundAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { + fields = append(fields, paymentorder.FieldRefundRequestedAt) + } + if m.FieldCleared(paymentorder.FieldRefundRequestReason) { + fields = append(fields, paymentorder.FieldRefundRequestReason) + } + if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { + fields = append(fields, paymentorder.FieldRefundRequestedBy) + } + if m.FieldCleared(paymentorder.FieldPaidAt) { + fields = append(fields, paymentorder.FieldPaidAt) + } + if m.FieldCleared(paymentorder.FieldCompletedAt) { + fields = append(fields, paymentorder.FieldCompletedAt) + } + if m.FieldCleared(paymentorder.FieldFailedAt) { + fields = append(fields, paymentorder.FieldFailedAt) + } + if m.FieldCleared(paymentorder.FieldFailedReason) { + fields = append(fields, paymentorder.FieldFailedReason) + } + if m.FieldCleared(paymentorder.FieldSrcURL) { + fields = append(fields, paymentorder.FieldSrcURL) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentOrderMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ClearField(name string) error { + switch name { + case paymentorder.FieldUserNotes: + m.ClearUserNotes() + return nil + case paymentorder.FieldPayURL: + m.ClearPayURL() + return nil + case paymentorder.FieldQrCode: + m.ClearQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ClearQrCodeImg() + return nil + case paymentorder.FieldPlanID: + m.ClearPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ClearSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ClearSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + m.ClearProviderInstanceID() return nil case paymentorder.FieldRefundReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundReason(v) + m.ClearRefundReason() return nil case paymentorder.FieldRefundAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundAt(v) + m.ClearRefundAt() + return nil + case paymentorder.FieldRefundRequestedAt: + m.ClearRefundRequestedAt() + return nil + case paymentorder.FieldRefundRequestReason: + m.ClearRefundRequestReason() + return nil + case paymentorder.FieldRefundRequestedBy: + m.ClearRefundRequestedBy() + return nil + case paymentorder.FieldPaidAt: + m.ClearPaidAt() + return nil + case paymentorder.FieldCompletedAt: + m.ClearCompletedAt() + return nil + case paymentorder.FieldFailedAt: + m.ClearFailedAt() + return nil + case paymentorder.FieldFailedReason: + m.ClearFailedReason() + return nil + case paymentorder.FieldSrcURL: + m.ClearSrcURL() + return nil + } + return fmt.Errorf("unknown PaymentOrder nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentOrderMutation) ResetField(name string) error { + switch name { + case paymentorder.FieldUserID: + m.ResetUserID() + return nil + case paymentorder.FieldUserEmail: + m.ResetUserEmail() + return nil + case paymentorder.FieldUserName: + m.ResetUserName() + return nil + case paymentorder.FieldUserNotes: + m.ResetUserNotes() + return nil + case paymentorder.FieldAmount: + m.ResetAmount() + return nil + case paymentorder.FieldPayAmount: + m.ResetPayAmount() + return nil + case paymentorder.FieldFeeRate: + m.ResetFeeRate() + return nil + case paymentorder.FieldRechargeCode: + m.ResetRechargeCode() + return nil + case paymentorder.FieldOutTradeNo: + m.ResetOutTradeNo() + return nil + case paymentorder.FieldPaymentType: + m.ResetPaymentType() + return nil + case paymentorder.FieldPaymentTradeNo: + m.ResetPaymentTradeNo() + return nil + case paymentorder.FieldPayURL: + m.ResetPayURL() + return nil + case paymentorder.FieldQrCode: + m.ResetQrCode() + return nil + case paymentorder.FieldQrCodeImg: + m.ResetQrCodeImg() + return nil + case paymentorder.FieldOrderType: + m.ResetOrderType() + return nil + case paymentorder.FieldPlanID: + m.ResetPlanID() + return nil + case paymentorder.FieldSubscriptionGroupID: + m.ResetSubscriptionGroupID() + return nil + case paymentorder.FieldSubscriptionDays: + m.ResetSubscriptionDays() + return nil + case paymentorder.FieldProviderInstanceID: + m.ResetProviderInstanceID() + return nil + case paymentorder.FieldStatus: + m.ResetStatus() + return nil + case paymentorder.FieldRefundAmount: + m.ResetRefundAmount() + return nil + case paymentorder.FieldRefundReason: + m.ResetRefundReason() + return nil + case paymentorder.FieldRefundAt: + m.ResetRefundAt() return nil case paymentorder.FieldForceRefund: - v, ok := value.(bool) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetForceRefund(v) + m.ResetForceRefund() return nil case paymentorder.FieldRefundRequestedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestedAt(v) + m.ResetRefundRequestedAt() return nil case paymentorder.FieldRefundRequestReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestReason(v) + m.ResetRefundRequestReason() return nil case paymentorder.FieldRefundRequestedBy: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRefundRequestedBy(v) + m.ResetRefundRequestedBy() return nil case paymentorder.FieldExpiresAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetExpiresAt(v) + m.ResetExpiresAt() return nil case paymentorder.FieldPaidAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPaidAt(v) + m.ResetPaidAt() return nil case paymentorder.FieldCompletedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCompletedAt(v) + m.ResetCompletedAt() return nil case paymentorder.FieldFailedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFailedAt(v) + m.ResetFailedAt() return nil case paymentorder.FieldFailedReason: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetFailedReason(v) + m.ResetFailedReason() return nil case paymentorder.FieldClientIP: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.ResetClientIP() + return nil + case paymentorder.FieldSrcHost: + m.ResetSrcHost() + return nil + case paymentorder.FieldSrcURL: + m.ResetSrcURL() + return nil + case paymentorder.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case paymentorder.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + } + return fmt.Errorf("unknown PaymentOrder field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *PaymentOrderMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.user != nil { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { + switch name { + case paymentorder.EdgeUser: + if id := m.user; id != nil { + return []ent.Value{*id} } - m.SetClientIP(v) + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *PaymentOrderMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *PaymentOrderMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.cleareduser { + edges = append(edges, paymentorder.EdgeUser) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *PaymentOrderMutation) EdgeCleared(name string) bool { + switch name { + case paymentorder.EdgeUser: + return m.cleareduser + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentOrderMutation) ClearEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ClearUser() return nil - case paymentorder.FieldSrcHost: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetSrcHost(v) + } + return fmt.Errorf("unknown PaymentOrder unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentOrderMutation) ResetEdge(name string) error { + switch name { + case paymentorder.EdgeUser: + m.ResetUser() return nil - case paymentorder.FieldSrcURL: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + } + return fmt.Errorf("unknown PaymentOrder edge %s", name) +} + +// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. +type PaymentProviderInstanceMutation struct { + config + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance +} + +var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) + +// paymentproviderinstanceOption allows management of the mutation configuration using functional options. +type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) + +// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. +func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { + m := &PaymentProviderInstanceMutation{ + config: c, + op: op, + typ: TypePaymentProviderInstance, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withPaymentProviderInstanceID sets the ID field of the mutation. +func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + var ( + err error + once sync.Once + value *PaymentProviderInstance + ) + m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + } + }) + return value, err } - m.SetSrcURL(v) - return nil - case paymentorder.FieldCreatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.id = &id + } +} + +// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. +func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { + return func(m *PaymentProviderInstanceMutation) { + m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { + return node, nil } - m.SetCreatedAt(v) - return nil - case paymentorder.FieldUpdatedAt: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m PaymentProviderInstanceMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil } - m.SetUpdatedAt(v) - return nil + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetProviderKey sets the "provider_key" field. +func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentProviderInstanceMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetName sets the "name" field. +func (m *PaymentProviderInstanceMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *PaymentProviderInstanceMutation) ResetName() { + m.name = nil +} + +// SetConfig sets the "config" field. +func (m *PaymentProviderInstanceMutation) SetConfig(s string) { + m._config = &s +} + +// Config returns the value of the "config" field in the mutation. +func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { + v := m._config + if v == nil { + return + } + return *v, true +} + +// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConfig is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConfig requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConfig: %w", err) + } + return oldValue.Config, nil +} + +// ResetConfig resets all changes to the "config" field. +func (m *PaymentProviderInstanceMutation) ResetConfig() { + m._config = nil +} + +// SetSupportedTypes sets the "supported_types" field. +func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { + m.supported_types = &s +} + +// SupportedTypes returns the value of the "supported_types" field in the mutation. +func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { + v := m.supported_types + if v == nil { + return + } + return *v, true +} + +// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + } + return oldValue.SupportedTypes, nil +} + +// ResetSupportedTypes resets all changes to the "supported_types" field. +func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { + m.supported_types = nil +} + +// SetEnabled sets the "enabled" field. +func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. +func (m *PaymentProviderInstanceMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPaymentMode sets the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { + m.payment_mode = &s +} + +// PaymentMode returns the value of the "payment_mode" field in the mutation. +func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { + v := m.payment_mode + if v == nil { + return + } + return *v, true +} + +// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPaymentMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + } + return oldValue.PaymentMode, nil +} + +// ResetPaymentMode resets all changes to the "payment_mode" field. +func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { + m.payment_mode = nil +} + +// SetSortOrder sets the "sort_order" field. +func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil +} + +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return + } + return *v, true +} + +// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil +} + +// AddSortOrder adds i to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i + } +} + +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { + return + } + return *v, true +} + +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *PaymentProviderInstanceMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil +} + +// SetLimits sets the "limits" field. +func (m *PaymentProviderInstanceMutation) SetLimits(s string) { + m.limits = &s +} + +// Limits returns the value of the "limits" field in the mutation. +func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { + v := m.limits + if v == nil { + return + } + return *v, true +} + +// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLimits is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLimits requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLimits: %w", err) + } + return oldValue.Limits, nil +} + +// ResetLimits resets all changes to the "limits" field. +func (m *PaymentProviderInstanceMutation) ResetLimits() { + m.limits = nil +} + +// SetRefundEnabled sets the "refund_enabled" field. +func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { + m.refund_enabled = &b +} + +// RefundEnabled returns the value of the "refund_enabled" field in the mutation. +func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { + v := m.refund_enabled + if v == nil { + return } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return *v, true } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *PaymentOrderMutation) AddedFields() []string { - var fields []string - if m.addamount != nil { - fields = append(fields, paymentorder.FieldAmount) +// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") } - if m.addpay_amount != nil { - fields = append(fields, paymentorder.FieldPayAmount) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRefundEnabled requires an ID field in the mutation") } - if m.addfee_rate != nil { - fields = append(fields, paymentorder.FieldFeeRate) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) } - if m.addplan_id != nil { - fields = append(fields, paymentorder.FieldPlanID) + return oldValue.RefundEnabled, nil +} + +// ResetRefundEnabled resets all changes to the "refund_enabled" field. +func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { + m.refund_enabled = nil +} + +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return } - if m.addsubscription_group_id != nil { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") } - if m.addsubscription_days != nil { - fields = append(fields, paymentorder.FieldSubscriptionDays) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") } - if m.addrefund_amount != nil { - fields = append(fields, paymentorder.FieldRefundAmount) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) } - return fields + return oldValue.AllowUserRefund, nil } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *PaymentOrderMutation) AddedField(name string) (ent.Value, bool) { - switch name { - case paymentorder.FieldAmount: - return m.AddedAmount() - case paymentorder.FieldPayAmount: - return m.AddedPayAmount() - case paymentorder.FieldFeeRate: - return m.AddedFeeRate() - case paymentorder.FieldPlanID: - return m.AddedPlanID() - case paymentorder.FieldSubscriptionGroupID: - return m.AddedSubscriptionGroupID() - case paymentorder.FieldSubscriptionDays: - return m.AddedSubscriptionDays() - case paymentorder.FieldRefundAmount: - return m.AddedRefundAmount() - } - return nil, false +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *PaymentOrderMutation) AddField(name string, value ent.Value) error { - switch name { - case paymentorder.FieldAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddAmount(v) - return nil - case paymentorder.FieldPayAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPayAmount(v) - return nil - case paymentorder.FieldFeeRate: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddFeeRate(v) - return nil - case paymentorder.FieldPlanID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddPlanID(v) - return nil - case paymentorder.FieldSubscriptionGroupID: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSubscriptionGroupID(v) - return nil - case paymentorder.FieldSubscriptionDays: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSubscriptionDays(v) - return nil - case paymentorder.FieldRefundAmount: - v, ok := value.(float64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddRefundAmount(v) - return nil +// SetCreatedAt sets the "created_at" field. +func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return } - return fmt.Errorf("unknown PaymentOrder numeric field %s", name) + return *v, true } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *PaymentOrderMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(paymentorder.FieldUserNotes) { - fields = append(fields, paymentorder.FieldUserNotes) +// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") } - if m.FieldCleared(paymentorder.FieldPayURL) { - fields = append(fields, paymentorder.FieldPayURL) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") } - if m.FieldCleared(paymentorder.FieldQrCode) { - fields = append(fields, paymentorder.FieldQrCode) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) } - if m.FieldCleared(paymentorder.FieldQrCodeImg) { - fields = append(fields, paymentorder.FieldQrCodeImg) + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return } - if m.FieldCleared(paymentorder.FieldPlanID) { - fields = append(fields, paymentorder.FieldPlanID) + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") } - if m.FieldCleared(paymentorder.FieldSubscriptionGroupID) { - fields = append(fields, paymentorder.FieldSubscriptionGroupID) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") } - if m.FieldCleared(paymentorder.FieldSubscriptionDays) { - fields = append(fields, paymentorder.FieldSubscriptionDays) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// Where appends a list predicates to the PaymentProviderInstanceMutation builder. +func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PaymentProviderInstance, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *PaymentProviderInstanceMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *PaymentProviderInstanceMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (PaymentProviderInstance). +func (m *PaymentProviderInstanceMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *PaymentProviderInstanceMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.provider_key != nil { + fields = append(fields, paymentproviderinstance.FieldProviderKey) } - if m.FieldCleared(paymentorder.FieldProviderInstanceID) { - fields = append(fields, paymentorder.FieldProviderInstanceID) + if m.name != nil { + fields = append(fields, paymentproviderinstance.FieldName) } - if m.FieldCleared(paymentorder.FieldRefundReason) { - fields = append(fields, paymentorder.FieldRefundReason) + if m._config != nil { + fields = append(fields, paymentproviderinstance.FieldConfig) } - if m.FieldCleared(paymentorder.FieldRefundAt) { - fields = append(fields, paymentorder.FieldRefundAt) + if m.supported_types != nil { + fields = append(fields, paymentproviderinstance.FieldSupportedTypes) } - if m.FieldCleared(paymentorder.FieldRefundRequestedAt) { - fields = append(fields, paymentorder.FieldRefundRequestedAt) + if m.enabled != nil { + fields = append(fields, paymentproviderinstance.FieldEnabled) } - if m.FieldCleared(paymentorder.FieldRefundRequestReason) { - fields = append(fields, paymentorder.FieldRefundRequestReason) + if m.payment_mode != nil { + fields = append(fields, paymentproviderinstance.FieldPaymentMode) } - if m.FieldCleared(paymentorder.FieldRefundRequestedBy) { - fields = append(fields, paymentorder.FieldRefundRequestedBy) + if m.sort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) } - if m.FieldCleared(paymentorder.FieldPaidAt) { - fields = append(fields, paymentorder.FieldPaidAt) + if m.limits != nil { + fields = append(fields, paymentproviderinstance.FieldLimits) } - if m.FieldCleared(paymentorder.FieldCompletedAt) { - fields = append(fields, paymentorder.FieldCompletedAt) + if m.refund_enabled != nil { + fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } - if m.FieldCleared(paymentorder.FieldFailedAt) { - fields = append(fields, paymentorder.FieldFailedAt) + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) } - if m.FieldCleared(paymentorder.FieldFailedReason) { - fields = append(fields, paymentorder.FieldFailedReason) + if m.created_at != nil { + fields = append(fields, paymentproviderinstance.FieldCreatedAt) } - if m.FieldCleared(paymentorder.FieldSrcURL) { - fields = append(fields, paymentorder.FieldSrcURL) + if m.updated_at != nil { + fields = append(fields, paymentproviderinstance.FieldUpdatedAt) } return fields } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *PaymentOrderMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { + switch name { + case paymentproviderinstance.FieldProviderKey: + return m.ProviderKey() + case paymentproviderinstance.FieldName: + return m.Name() + case paymentproviderinstance.FieldConfig: + return m.Config() + case paymentproviderinstance.FieldSupportedTypes: + return m.SupportedTypes() + case paymentproviderinstance.FieldEnabled: + return m.Enabled() + case paymentproviderinstance.FieldPaymentMode: + return m.PaymentMode() + case paymentproviderinstance.FieldSortOrder: + return m.SortOrder() + case paymentproviderinstance.FieldLimits: + return m.Limits() + case paymentproviderinstance.FieldRefundEnabled: + return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() + case paymentproviderinstance.FieldCreatedAt: + return m.CreatedAt() + case paymentproviderinstance.FieldUpdatedAt: + return m.UpdatedAt() + } + return nil, false } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ClearField(name string) error { +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentorder.FieldUserNotes: - m.ClearUserNotes() - return nil - case paymentorder.FieldPayURL: - m.ClearPayURL() - return nil - case paymentorder.FieldQrCode: - m.ClearQrCode() - return nil - case paymentorder.FieldQrCodeImg: - m.ClearQrCodeImg() - return nil - case paymentorder.FieldPlanID: - m.ClearPlanID() - return nil - case paymentorder.FieldSubscriptionGroupID: - m.ClearSubscriptionGroupID() - return nil - case paymentorder.FieldSubscriptionDays: - m.ClearSubscriptionDays() - return nil - case paymentorder.FieldProviderInstanceID: - m.ClearProviderInstanceID() - return nil - case paymentorder.FieldRefundReason: - m.ClearRefundReason() - return nil - case paymentorder.FieldRefundAt: - m.ClearRefundAt() - return nil - case paymentorder.FieldRefundRequestedAt: - m.ClearRefundRequestedAt() - return nil - case paymentorder.FieldRefundRequestReason: - m.ClearRefundRequestReason() - return nil - case paymentorder.FieldRefundRequestedBy: - m.ClearRefundRequestedBy() - return nil - case paymentorder.FieldPaidAt: - m.ClearPaidAt() - return nil - case paymentorder.FieldCompletedAt: - m.ClearCompletedAt() - return nil - case paymentorder.FieldFailedAt: - m.ClearFailedAt() - return nil - case paymentorder.FieldFailedReason: - m.ClearFailedReason() - return nil - case paymentorder.FieldSrcURL: - m.ClearSrcURL() - return nil + case paymentproviderinstance.FieldProviderKey: + return m.OldProviderKey(ctx) + case paymentproviderinstance.FieldName: + return m.OldName(ctx) + case paymentproviderinstance.FieldConfig: + return m.OldConfig(ctx) + case paymentproviderinstance.FieldSupportedTypes: + return m.OldSupportedTypes(ctx) + case paymentproviderinstance.FieldEnabled: + return m.OldEnabled(ctx) + case paymentproviderinstance.FieldPaymentMode: + return m.OldPaymentMode(ctx) + case paymentproviderinstance.FieldSortOrder: + return m.OldSortOrder(ctx) + case paymentproviderinstance.FieldLimits: + return m.OldLimits(ctx) + case paymentproviderinstance.FieldRefundEnabled: + return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) + case paymentproviderinstance.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case paymentproviderinstance.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) } - return fmt.Errorf("unknown PaymentOrder nullable field %s", name) + return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *PaymentOrderMutation) ResetField(name string) error { +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { switch name { - case paymentorder.FieldUserID: - m.ResetUserID() - return nil - case paymentorder.FieldUserEmail: - m.ResetUserEmail() - return nil - case paymentorder.FieldUserName: - m.ResetUserName() - return nil - case paymentorder.FieldUserNotes: - m.ResetUserNotes() - return nil - case paymentorder.FieldAmount: - m.ResetAmount() - return nil - case paymentorder.FieldPayAmount: - m.ResetPayAmount() - return nil - case paymentorder.FieldFeeRate: - m.ResetFeeRate() - return nil - case paymentorder.FieldRechargeCode: - m.ResetRechargeCode() - return nil - case paymentorder.FieldOutTradeNo: - m.ResetOutTradeNo() - return nil - case paymentorder.FieldPaymentType: - m.ResetPaymentType() - return nil - case paymentorder.FieldPaymentTradeNo: - m.ResetPaymentTradeNo() - return nil - case paymentorder.FieldPayURL: - m.ResetPayURL() - return nil - case paymentorder.FieldQrCode: - m.ResetQrCode() + case paymentproviderinstance.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) return nil - case paymentorder.FieldQrCodeImg: - m.ResetQrCodeImg() + case paymentproviderinstance.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) return nil - case paymentorder.FieldOrderType: - m.ResetOrderType() + case paymentproviderinstance.FieldConfig: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConfig(v) return nil - case paymentorder.FieldPlanID: - m.ResetPlanID() + case paymentproviderinstance.FieldSupportedTypes: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedTypes(v) return nil - case paymentorder.FieldSubscriptionGroupID: - m.ResetSubscriptionGroupID() + case paymentproviderinstance.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) return nil - case paymentorder.FieldSubscriptionDays: - m.ResetSubscriptionDays() + case paymentproviderinstance.FieldPaymentMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPaymentMode(v) return nil - case paymentorder.FieldProviderInstanceID: - m.ResetProviderInstanceID() + case paymentproviderinstance.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) return nil - case paymentorder.FieldStatus: - m.ResetStatus() + case paymentproviderinstance.FieldLimits: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLimits(v) return nil - case paymentorder.FieldRefundAmount: - m.ResetRefundAmount() + case paymentproviderinstance.FieldRefundEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRefundEnabled(v) return nil - case paymentorder.FieldRefundReason: - m.ResetRefundReason() + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) return nil - case paymentorder.FieldRefundAt: - m.ResetRefundAt() + case paymentproviderinstance.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) return nil - case paymentorder.FieldForceRefund: - m.ResetForceRefund() + case paymentproviderinstance.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) return nil - case paymentorder.FieldRefundRequestedAt: - m.ResetRefundRequestedAt() + } + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *PaymentProviderInstanceMutation) AddedFields() []string { + var fields []string + if m.addsort_order != nil { + fields = append(fields, paymentproviderinstance.FieldSortOrder) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case paymentproviderinstance.FieldSortOrder: + return m.AddedSortOrder() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { + switch name { + case paymentproviderinstance.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) return nil - case paymentorder.FieldRefundRequestReason: - m.ResetRefundRequestReason() + } + return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *PaymentProviderInstanceMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ClearField(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ResetField(name string) error { + switch name { + case paymentproviderinstance.FieldProviderKey: + m.ResetProviderKey() return nil - case paymentorder.FieldRefundRequestedBy: - m.ResetRefundRequestedBy() + case paymentproviderinstance.FieldName: + m.ResetName() return nil - case paymentorder.FieldExpiresAt: - m.ResetExpiresAt() + case paymentproviderinstance.FieldConfig: + m.ResetConfig() return nil - case paymentorder.FieldPaidAt: - m.ResetPaidAt() + case paymentproviderinstance.FieldSupportedTypes: + m.ResetSupportedTypes() return nil - case paymentorder.FieldCompletedAt: - m.ResetCompletedAt() + case paymentproviderinstance.FieldEnabled: + m.ResetEnabled() return nil - case paymentorder.FieldFailedAt: - m.ResetFailedAt() + case paymentproviderinstance.FieldPaymentMode: + m.ResetPaymentMode() return nil - case paymentorder.FieldFailedReason: - m.ResetFailedReason() + case paymentproviderinstance.FieldSortOrder: + m.ResetSortOrder() return nil - case paymentorder.FieldClientIP: - m.ResetClientIP() + case paymentproviderinstance.FieldLimits: + m.ResetLimits() return nil - case paymentorder.FieldSrcHost: - m.ResetSrcHost() + case paymentproviderinstance.FieldRefundEnabled: + m.ResetRefundEnabled() return nil - case paymentorder.FieldSrcURL: - m.ResetSrcURL() + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() return nil - case paymentorder.FieldCreatedAt: + case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil - case paymentorder.FieldUpdatedAt: + case paymentproviderinstance.FieldUpdatedAt: m.ResetUpdatedAt() return nil } - return fmt.Errorf("unknown PaymentOrder field %s", name) + return fmt.Errorf("unknown PaymentProviderInstance field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentOrderMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.user != nil { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) AddedEdges() []string { + edges := make([]string, 0, 0) return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentOrderMutation) AddedIDs(name string) []ent.Value { - switch name { - case paymentorder.EdgeUser: - if id := m.user; id != nil { - return []ent.Value{*id} - } - } +func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentOrderMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) +func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentOrderMutation) RemovedIDs(name string) []ent.Value { +func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentOrderMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.cleareduser { - edges = append(edges, paymentorder.EdgeUser) - } +func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentOrderMutation) EdgeCleared(name string) bool { - switch name { - case paymentorder.EdgeUser: - return m.cleareduser - } - return false -} - -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *PaymentOrderMutation) ClearEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ClearUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder unique edge %s", name) -} - -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *PaymentOrderMutation) ResetEdge(name string) error { - switch name { - case paymentorder.EdgeUser: - m.ResetUser() - return nil - } - return fmt.Errorf("unknown PaymentOrder edge %s", name) -} - -// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. -type PaymentProviderInstanceMutation struct { - config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - allow_user_refund *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance +func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { + return false } -var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) +} -// paymentproviderinstanceOption allows management of the mutation configuration using functional options. -type paymentproviderinstanceOption func(*PaymentProviderInstanceMutation) +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) +} -// newPaymentProviderInstanceMutation creates new mutation for the PaymentProviderInstance entity. -func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentproviderinstanceOption) *PaymentProviderInstanceMutation { - m := &PaymentProviderInstanceMutation{ +// PendingAuthSessionMutation represents an operation that mutates the PendingAuthSession nodes in the graph. +type PendingAuthSessionMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + session_token *string + intent *string + provider_type *string + provider_key *string + provider_subject *string + redirect_to *string + resolved_email *string + registration_password_hash *string + upstream_identity_claims *map[string]interface{} + local_flow_state *map[string]interface{} + browser_session_key *string + completion_code_hash *string + completion_code_expires_at *time.Time + email_verified_at *time.Time + password_verified_at *time.Time + totp_verified_at *time.Time + expires_at *time.Time + consumed_at *time.Time + clearedFields map[string]struct{} + target_user *int64 + clearedtarget_user bool + adoption_decision *int64 + clearedadoption_decision bool + done bool + oldValue func(context.Context) (*PendingAuthSession, error) + predicates []predicate.PendingAuthSession +} + +var _ ent.Mutation = (*PendingAuthSessionMutation)(nil) + +// pendingauthsessionOption allows management of the mutation configuration using functional options. +type pendingauthsessionOption func(*PendingAuthSessionMutation) + +// newPendingAuthSessionMutation creates new mutation for the PendingAuthSession entity. +func newPendingAuthSessionMutation(c config, op Op, opts ...pendingauthsessionOption) *PendingAuthSessionMutation { + m := &PendingAuthSessionMutation{ config: c, op: op, - typ: TypePaymentProviderInstance, + typ: TypePendingAuthSession, clearedFields: make(map[string]struct{}), } for _, opt := range opts { @@ -15683,20 +19272,20 @@ func newPaymentProviderInstanceMutation(c config, op Op, opts ...paymentprovider return m } -// withPaymentProviderInstanceID sets the ID field of the mutation. -func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { +// withPendingAuthSessionID sets the ID field of the mutation. +func withPendingAuthSessionID(id int64) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { var ( err error once sync.Once - value *PaymentProviderInstance + value *PendingAuthSession ) - m.oldValue = func(ctx context.Context) (*PaymentProviderInstance, error) { + m.oldValue = func(ctx context.Context) (*PendingAuthSession, error) { once.Do(func() { if m.done { err = errors.New("querying old values post mutation is not allowed") } else { - value, err = m.Client().PaymentProviderInstance.Get(ctx, id) + value, err = m.Client().PendingAuthSession.Get(ctx, id) } }) return value, err @@ -15705,10 +19294,10 @@ func withPaymentProviderInstanceID(id int64) paymentproviderinstanceOption { } } -// withPaymentProviderInstance sets the old PaymentProviderInstance of the mutation. -func withPaymentProviderInstance(node *PaymentProviderInstance) paymentproviderinstanceOption { - return func(m *PaymentProviderInstanceMutation) { - m.oldValue = func(context.Context) (*PaymentProviderInstance, error) { +// withPendingAuthSession sets the old PendingAuthSession of the mutation. +func withPendingAuthSession(node *PendingAuthSession) pendingauthsessionOption { + return func(m *PendingAuthSessionMutation) { + m.oldValue = func(context.Context) (*PendingAuthSession, error) { return node, nil } m.id = &node.ID @@ -15717,7 +19306,7 @@ func withPaymentProviderInstance(node *PaymentProviderInstance) paymentprovideri // Client returns a new `ent.Client` from the mutation. If the mutation was // executed in a transaction (ent.Tx), a transactional client is returned. -func (m PaymentProviderInstanceMutation) Client() *Client { +func (m PendingAuthSessionMutation) Client() *Client { client := &Client{config: m.config} client.init() return client @@ -15725,7 +19314,7 @@ func (m PaymentProviderInstanceMutation) Client() *Client { // Tx returns an `ent.Tx` for mutations that were executed in transactions; // it returns an error otherwise. -func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { +func (m PendingAuthSessionMutation) Tx() (*Tx, error) { if _, ok := m.driver.(*txDriver); !ok { return nil, errors.New("ent: mutation is not running in a transaction") } @@ -15734,495 +19323,943 @@ func (m PaymentProviderInstanceMutation) Tx() (*Tx, error) { return tx, nil } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *PaymentProviderInstanceMutation) ID() (id int64, exists bool) { - if m.id == nil { - return - } - return *m.id, true +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *PendingAuthSessionMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *PendingAuthSessionMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().PendingAuthSession.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *PendingAuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *PendingAuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *PendingAuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *PendingAuthSessionMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *PendingAuthSessionMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *PendingAuthSessionMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetSessionToken sets the "session_token" field. +func (m *PendingAuthSessionMutation) SetSessionToken(s string) { + m.session_token = &s +} + +// SessionToken returns the value of the "session_token" field in the mutation. +func (m *PendingAuthSessionMutation) SessionToken() (r string, exists bool) { + v := m.session_token + if v == nil { + return + } + return *v, true +} + +// OldSessionToken returns the old "session_token" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldSessionToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSessionToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSessionToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSessionToken: %w", err) + } + return oldValue.SessionToken, nil +} + +// ResetSessionToken resets all changes to the "session_token" field. +func (m *PendingAuthSessionMutation) ResetSessionToken() { + m.session_token = nil +} + +// SetIntent sets the "intent" field. +func (m *PendingAuthSessionMutation) SetIntent(s string) { + m.intent = &s +} + +// Intent returns the value of the "intent" field in the mutation. +func (m *PendingAuthSessionMutation) Intent() (r string, exists bool) { + v := m.intent + if v == nil { + return + } + return *v, true +} + +// OldIntent returns the old "intent" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldIntent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIntent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIntent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIntent: %w", err) + } + return oldValue.Intent, nil +} + +// ResetIntent resets all changes to the "intent" field. +func (m *PendingAuthSessionMutation) ResetIntent() { + m.intent = nil +} + +// SetProviderType sets the "provider_type" field. +func (m *PendingAuthSessionMutation) SetProviderType(s string) { + m.provider_type = &s +} + +// ProviderType returns the value of the "provider_type" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderType() (r string, exists bool) { + v := m.provider_type + if v == nil { + return + } + return *v, true +} + +// OldProviderType returns the old "provider_type" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderType: %w", err) + } + return oldValue.ProviderType, nil +} + +// ResetProviderType resets all changes to the "provider_type" field. +func (m *PendingAuthSessionMutation) ResetProviderType() { + m.provider_type = nil +} + +// SetProviderKey sets the "provider_key" field. +func (m *PendingAuthSessionMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PendingAuthSessionMutation) ResetProviderKey() { + m.provider_key = nil +} + +// SetProviderSubject sets the "provider_subject" field. +func (m *PendingAuthSessionMutation) SetProviderSubject(s string) { + m.provider_subject = &s +} + +// ProviderSubject returns the value of the "provider_subject" field in the mutation. +func (m *PendingAuthSessionMutation) ProviderSubject() (r string, exists bool) { + v := m.provider_subject + if v == nil { + return + } + return *v, true +} + +// OldProviderSubject returns the old "provider_subject" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldProviderSubject(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderSubject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderSubject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderSubject: %w", err) + } + return oldValue.ProviderSubject, nil +} + +// ResetProviderSubject resets all changes to the "provider_subject" field. +func (m *PendingAuthSessionMutation) ResetProviderSubject() { + m.provider_subject = nil +} + +// SetTargetUserID sets the "target_user_id" field. +func (m *PendingAuthSessionMutation) SetTargetUserID(i int64) { + m.target_user = &i +} + +// TargetUserID returns the value of the "target_user_id" field in the mutation. +func (m *PendingAuthSessionMutation) TargetUserID() (r int64, exists bool) { + v := m.target_user + if v == nil { + return + } + return *v, true +} + +// OldTargetUserID returns the old "target_user_id" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldTargetUserID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTargetUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTargetUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTargetUserID: %w", err) + } + return oldValue.TargetUserID, nil +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (m *PendingAuthSessionMutation) ClearTargetUserID() { + m.target_user = nil + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *PaymentProviderInstanceMutation) IDs(ctx context.Context) ([]int64, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []int64{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().PaymentProviderInstance.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// TargetUserIDCleared returns if the "target_user_id" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TargetUserIDCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTargetUserID] + return ok } -// SetProviderKey sets the "provider_key" field. -func (m *PaymentProviderInstanceMutation) SetProviderKey(s string) { - m.provider_key = &s +// ResetTargetUserID resets all changes to the "target_user_id" field. +func (m *PendingAuthSessionMutation) ResetTargetUserID() { + m.target_user = nil + delete(m.clearedFields, pendingauthsession.FieldTargetUserID) } -// ProviderKey returns the value of the "provider_key" field in the mutation. -func (m *PaymentProviderInstanceMutation) ProviderKey() (r string, exists bool) { - v := m.provider_key +// SetRedirectTo sets the "redirect_to" field. +func (m *PendingAuthSessionMutation) SetRedirectTo(s string) { + m.redirect_to = &s +} + +// RedirectTo returns the value of the "redirect_to" field in the mutation. +func (m *PendingAuthSessionMutation) RedirectTo() (r string, exists bool) { + v := m.redirect_to if v == nil { return } return *v, true } -// OldProviderKey returns the old "provider_key" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRedirectTo returns the old "redirect_to" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldProviderKey(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRedirectTo(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + return v, errors.New("OldRedirectTo is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldProviderKey requires an ID field in the mutation") + return v, errors.New("OldRedirectTo requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + return v, fmt.Errorf("querying old value for OldRedirectTo: %w", err) } - return oldValue.ProviderKey, nil + return oldValue.RedirectTo, nil } -// ResetProviderKey resets all changes to the "provider_key" field. -func (m *PaymentProviderInstanceMutation) ResetProviderKey() { - m.provider_key = nil +// ResetRedirectTo resets all changes to the "redirect_to" field. +func (m *PendingAuthSessionMutation) ResetRedirectTo() { + m.redirect_to = nil } -// SetName sets the "name" field. -func (m *PaymentProviderInstanceMutation) SetName(s string) { - m.name = &s +// SetResolvedEmail sets the "resolved_email" field. +func (m *PendingAuthSessionMutation) SetResolvedEmail(s string) { + m.resolved_email = &s } -// Name returns the value of the "name" field in the mutation. -func (m *PaymentProviderInstanceMutation) Name() (r string, exists bool) { - v := m.name +// ResolvedEmail returns the value of the "resolved_email" field in the mutation. +func (m *PendingAuthSessionMutation) ResolvedEmail() (r string, exists bool) { + v := m.resolved_email if v == nil { return } return *v, true } -// OldName returns the old "name" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldResolvedEmail returns the old "resolved_email" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldName(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldResolvedEmail(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") + return v, errors.New("OldResolvedEmail is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + return v, errors.New("OldResolvedEmail requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + return v, fmt.Errorf("querying old value for OldResolvedEmail: %w", err) } - return oldValue.Name, nil + return oldValue.ResolvedEmail, nil } -// ResetName resets all changes to the "name" field. -func (m *PaymentProviderInstanceMutation) ResetName() { - m.name = nil +// ResetResolvedEmail resets all changes to the "resolved_email" field. +func (m *PendingAuthSessionMutation) ResetResolvedEmail() { + m.resolved_email = nil } -// SetConfig sets the "config" field. -func (m *PaymentProviderInstanceMutation) SetConfig(s string) { - m._config = &s +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) SetRegistrationPasswordHash(s string) { + m.registration_password_hash = &s } -// Config returns the value of the "config" field in the mutation. -func (m *PaymentProviderInstanceMutation) Config() (r string, exists bool) { - v := m._config +// RegistrationPasswordHash returns the value of the "registration_password_hash" field in the mutation. +func (m *PendingAuthSessionMutation) RegistrationPasswordHash() (r string, exists bool) { + v := m.registration_password_hash if v == nil { return } return *v, true } -// OldConfig returns the old "config" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldRegistrationPasswordHash returns the old "registration_password_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldConfig(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldRegistrationPasswordHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldConfig is only allowed on UpdateOne operations") + return v, errors.New("OldRegistrationPasswordHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldConfig requires an ID field in the mutation") + return v, errors.New("OldRegistrationPasswordHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldConfig: %w", err) + return v, fmt.Errorf("querying old value for OldRegistrationPasswordHash: %w", err) } - return oldValue.Config, nil + return oldValue.RegistrationPasswordHash, nil } -// ResetConfig resets all changes to the "config" field. -func (m *PaymentProviderInstanceMutation) ResetConfig() { - m._config = nil +// ResetRegistrationPasswordHash resets all changes to the "registration_password_hash" field. +func (m *PendingAuthSessionMutation) ResetRegistrationPasswordHash() { + m.registration_password_hash = nil } -// SetSupportedTypes sets the "supported_types" field. -func (m *PaymentProviderInstanceMutation) SetSupportedTypes(s string) { - m.supported_types = &s +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) SetUpstreamIdentityClaims(value map[string]interface{}) { + m.upstream_identity_claims = &value } -// SupportedTypes returns the value of the "supported_types" field in the mutation. -func (m *PaymentProviderInstanceMutation) SupportedTypes() (r string, exists bool) { - v := m.supported_types +// UpstreamIdentityClaims returns the value of the "upstream_identity_claims" field in the mutation. +func (m *PendingAuthSessionMutation) UpstreamIdentityClaims() (r map[string]interface{}, exists bool) { + v := m.upstream_identity_claims if v == nil { return } return *v, true } -// OldSupportedTypes returns the old "supported_types" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldUpstreamIdentityClaims returns the old "upstream_identity_claims" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldSupportedTypes(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldUpstreamIdentityClaims(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSupportedTypes is only allowed on UpdateOne operations") + return v, errors.New("OldUpstreamIdentityClaims is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSupportedTypes requires an ID field in the mutation") + return v, errors.New("OldUpstreamIdentityClaims requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSupportedTypes: %w", err) + return v, fmt.Errorf("querying old value for OldUpstreamIdentityClaims: %w", err) } - return oldValue.SupportedTypes, nil + return oldValue.UpstreamIdentityClaims, nil } -// ResetSupportedTypes resets all changes to the "supported_types" field. -func (m *PaymentProviderInstanceMutation) ResetSupportedTypes() { - m.supported_types = nil +// ResetUpstreamIdentityClaims resets all changes to the "upstream_identity_claims" field. +func (m *PendingAuthSessionMutation) ResetUpstreamIdentityClaims() { + m.upstream_identity_claims = nil } -// SetEnabled sets the "enabled" field. -func (m *PaymentProviderInstanceMutation) SetEnabled(b bool) { - m.enabled = &b +// SetLocalFlowState sets the "local_flow_state" field. +func (m *PendingAuthSessionMutation) SetLocalFlowState(value map[string]interface{}) { + m.local_flow_state = &value } -// Enabled returns the value of the "enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) Enabled() (r bool, exists bool) { - v := m.enabled +// LocalFlowState returns the value of the "local_flow_state" field in the mutation. +func (m *PendingAuthSessionMutation) LocalFlowState() (r map[string]interface{}, exists bool) { + v := m.local_flow_state if v == nil { return } return *v, true } -// OldEnabled returns the old "enabled" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldLocalFlowState returns the old "local_flow_state" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldLocalFlowState(ctx context.Context) (v map[string]interface{}, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldLocalFlowState is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldEnabled requires an ID field in the mutation") + return v, errors.New("OldLocalFlowState requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldLocalFlowState: %w", err) } - return oldValue.Enabled, nil + return oldValue.LocalFlowState, nil } -// ResetEnabled resets all changes to the "enabled" field. -func (m *PaymentProviderInstanceMutation) ResetEnabled() { - m.enabled = nil +// ResetLocalFlowState resets all changes to the "local_flow_state" field. +func (m *PendingAuthSessionMutation) ResetLocalFlowState() { + m.local_flow_state = nil } -// SetPaymentMode sets the "payment_mode" field. -func (m *PaymentProviderInstanceMutation) SetPaymentMode(s string) { - m.payment_mode = &s +// SetBrowserSessionKey sets the "browser_session_key" field. +func (m *PendingAuthSessionMutation) SetBrowserSessionKey(s string) { + m.browser_session_key = &s } -// PaymentMode returns the value of the "payment_mode" field in the mutation. -func (m *PaymentProviderInstanceMutation) PaymentMode() (r string, exists bool) { - v := m.payment_mode +// BrowserSessionKey returns the value of the "browser_session_key" field in the mutation. +func (m *PendingAuthSessionMutation) BrowserSessionKey() (r string, exists bool) { + v := m.browser_session_key if v == nil { return } return *v, true } -// OldPaymentMode returns the old "payment_mode" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldBrowserSessionKey returns the old "browser_session_key" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldPaymentMode(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldBrowserSessionKey(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPaymentMode is only allowed on UpdateOne operations") + return v, errors.New("OldBrowserSessionKey is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPaymentMode requires an ID field in the mutation") + return v, errors.New("OldBrowserSessionKey requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldPaymentMode: %w", err) + return v, fmt.Errorf("querying old value for OldBrowserSessionKey: %w", err) } - return oldValue.PaymentMode, nil + return oldValue.BrowserSessionKey, nil } -// ResetPaymentMode resets all changes to the "payment_mode" field. -func (m *PaymentProviderInstanceMutation) ResetPaymentMode() { - m.payment_mode = nil +// ResetBrowserSessionKey resets all changes to the "browser_session_key" field. +func (m *PendingAuthSessionMutation) ResetBrowserSessionKey() { + m.browser_session_key = nil } -// SetSortOrder sets the "sort_order" field. -func (m *PaymentProviderInstanceMutation) SetSortOrder(i int) { - m.sort_order = &i - m.addsort_order = nil +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeHash(s string) { + m.completion_code_hash = &s } -// SortOrder returns the value of the "sort_order" field in the mutation. -func (m *PaymentProviderInstanceMutation) SortOrder() (r int, exists bool) { - v := m.sort_order +// CompletionCodeHash returns the value of the "completion_code_hash" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeHash() (r string, exists bool) { + v := m.completion_code_hash if v == nil { return } return *v, true } -// OldSortOrder returns the old "sort_order" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldCompletionCodeHash returns the old "completion_code_hash" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldSortOrder(ctx context.Context) (v int, err error) { +func (m *PendingAuthSessionMutation) OldCompletionCodeHash(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + return v, errors.New("OldCompletionCodeHash is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldSortOrder requires an ID field in the mutation") + return v, errors.New("OldCompletionCodeHash requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + return v, fmt.Errorf("querying old value for OldCompletionCodeHash: %w", err) } - return oldValue.SortOrder, nil + return oldValue.CompletionCodeHash, nil } -// AddSortOrder adds i to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) AddSortOrder(i int) { - if m.addsort_order != nil { - *m.addsort_order += i - } else { - m.addsort_order = &i - } +// ResetCompletionCodeHash resets all changes to the "completion_code_hash" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeHash() { + m.completion_code_hash = nil } -// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. -func (m *PaymentProviderInstanceMutation) AddedSortOrder() (r int, exists bool) { - v := m.addsort_order +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) SetCompletionCodeExpiresAt(t time.Time) { + m.completion_code_expires_at = &t +} + +// CompletionCodeExpiresAt returns the value of the "completion_code_expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAt() (r time.Time, exists bool) { + v := m.completion_code_expires_at if v == nil { return } return *v, true } -// ResetSortOrder resets all changes to the "sort_order" field. -func (m *PaymentProviderInstanceMutation) ResetSortOrder() { - m.sort_order = nil - m.addsort_order = nil +// OldCompletionCodeExpiresAt returns the old "completion_code_expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PendingAuthSessionMutation) OldCompletionCodeExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCompletionCodeExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCompletionCodeExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCompletionCodeExpiresAt: %w", err) + } + return oldValue.CompletionCodeExpiresAt, nil } -// SetLimits sets the "limits" field. -func (m *PaymentProviderInstanceMutation) SetLimits(s string) { - m.limits = &s +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ClearCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] = struct{}{} } -// Limits returns the value of the "limits" field in the mutation. -func (m *PaymentProviderInstanceMutation) Limits() (r string, exists bool) { - v := m.limits +// CompletionCodeExpiresAtCleared returns if the "completion_code_expires_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) CompletionCodeExpiresAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldCompletionCodeExpiresAt] + return ok +} + +// ResetCompletionCodeExpiresAt resets all changes to the "completion_code_expires_at" field. +func (m *PendingAuthSessionMutation) ResetCompletionCodeExpiresAt() { + m.completion_code_expires_at = nil + delete(m.clearedFields, pendingauthsession.FieldCompletionCodeExpiresAt) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (m *PendingAuthSessionMutation) SetEmailVerifiedAt(t time.Time) { + m.email_verified_at = &t +} + +// EmailVerifiedAt returns the value of the "email_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAt() (r time.Time, exists bool) { + v := m.email_verified_at if v == nil { return } return *v, true } -// OldLimits returns the old "limits" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldEmailVerifiedAt returns the old "email_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldLimits(ctx context.Context) (v string, err error) { +func (m *PendingAuthSessionMutation) OldEmailVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLimits is only allowed on UpdateOne operations") + return v, errors.New("OldEmailVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLimits requires an ID field in the mutation") + return v, errors.New("OldEmailVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldLimits: %w", err) + return v, fmt.Errorf("querying old value for OldEmailVerifiedAt: %w", err) } - return oldValue.Limits, nil + return oldValue.EmailVerifiedAt, nil } -// ResetLimits resets all changes to the "limits" field. -func (m *PaymentProviderInstanceMutation) ResetLimits() { - m.limits = nil +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ClearEmailVerifiedAt() { + m.email_verified_at = nil + m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] = struct{}{} } -// SetRefundEnabled sets the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) SetRefundEnabled(b bool) { - m.refund_enabled = &b +// EmailVerifiedAtCleared returns if the "email_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) EmailVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldEmailVerifiedAt] + return ok } -// RefundEnabled returns the value of the "refund_enabled" field in the mutation. -func (m *PaymentProviderInstanceMutation) RefundEnabled() (r bool, exists bool) { - v := m.refund_enabled +// ResetEmailVerifiedAt resets all changes to the "email_verified_at" field. +func (m *PendingAuthSessionMutation) ResetEmailVerifiedAt() { + m.email_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldEmailVerifiedAt) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (m *PendingAuthSessionMutation) SetPasswordVerifiedAt(t time.Time) { + m.password_verified_at = &t +} + +// PasswordVerifiedAt returns the value of the "password_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAt() (r time.Time, exists bool) { + v := m.password_verified_at if v == nil { return } return *v, true } -// OldRefundEnabled returns the old "refund_enabled" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldPasswordVerifiedAt returns the old "password_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldRefundEnabled(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldPasswordVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRefundEnabled is only allowed on UpdateOne operations") + return v, errors.New("OldPasswordVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRefundEnabled requires an ID field in the mutation") + return v, errors.New("OldPasswordVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRefundEnabled: %w", err) + return v, fmt.Errorf("querying old value for OldPasswordVerifiedAt: %w", err) } - return oldValue.RefundEnabled, nil + return oldValue.PasswordVerifiedAt, nil } -// ResetRefundEnabled resets all changes to the "refund_enabled" field. -func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { - m.refund_enabled = nil +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ClearPasswordVerifiedAt() { + m.password_verified_at = nil + m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] = struct{}{} } -// SetAllowUserRefund sets the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { - m.allow_user_refund = &b +// PasswordVerifiedAtCleared returns if the "password_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) PasswordVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldPasswordVerifiedAt] + return ok } -// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. -func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { - v := m.allow_user_refund +// ResetPasswordVerifiedAt resets all changes to the "password_verified_at" field. +func (m *PendingAuthSessionMutation) ResetPasswordVerifiedAt() { + m.password_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldPasswordVerifiedAt) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) SetTotpVerifiedAt(t time.Time) { + m.totp_verified_at = &t +} + +// TotpVerifiedAt returns the value of the "totp_verified_at" field in the mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAt() (r time.Time, exists bool) { + v := m.totp_verified_at if v == nil { return } return *v, true } -// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldTotpVerifiedAt returns the old "totp_verified_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { +func (m *PendingAuthSessionMutation) OldTotpVerifiedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + return v, errors.New("OldTotpVerifiedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + return v, errors.New("OldTotpVerifiedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + return v, fmt.Errorf("querying old value for OldTotpVerifiedAt: %w", err) } - return oldValue.AllowUserRefund, nil + return oldValue.TotpVerifiedAt, nil } -// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. -func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { - m.allow_user_refund = nil +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ClearTotpVerifiedAt() { + m.totp_verified_at = nil + m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] = struct{}{} } -// SetCreatedAt sets the "created_at" field. -func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { - m.created_at = &t +// TotpVerifiedAtCleared returns if the "totp_verified_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) TotpVerifiedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldTotpVerifiedAt] + return ok } -// CreatedAt returns the value of the "created_at" field in the mutation. -func (m *PaymentProviderInstanceMutation) CreatedAt() (r time.Time, exists bool) { - v := m.created_at +// ResetTotpVerifiedAt resets all changes to the "totp_verified_at" field. +func (m *PendingAuthSessionMutation) ResetTotpVerifiedAt() { + m.totp_verified_at = nil + delete(m.clearedFields, pendingauthsession.FieldTotpVerifiedAt) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *PendingAuthSessionMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *PendingAuthSessionMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at if v == nil { return } return *v, true } -// OldCreatedAt returns the old "created_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldExpiresAt returns the old "expires_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreatedAt requires an ID field in the mutation") + return v, errors.New("OldExpiresAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) } - return oldValue.CreatedAt, nil + return oldValue.ExpiresAt, nil } -// ResetCreatedAt resets all changes to the "created_at" field. -func (m *PaymentProviderInstanceMutation) ResetCreatedAt() { - m.created_at = nil +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *PendingAuthSessionMutation) ResetExpiresAt() { + m.expires_at = nil } -// SetUpdatedAt sets the "updated_at" field. -func (m *PaymentProviderInstanceMutation) SetUpdatedAt(t time.Time) { - m.updated_at = &t +// SetConsumedAt sets the "consumed_at" field. +func (m *PendingAuthSessionMutation) SetConsumedAt(t time.Time) { + m.consumed_at = &t } -// UpdatedAt returns the value of the "updated_at" field in the mutation. -func (m *PaymentProviderInstanceMutation) UpdatedAt() (r time.Time, exists bool) { - v := m.updated_at +// ConsumedAt returns the value of the "consumed_at" field in the mutation. +func (m *PendingAuthSessionMutation) ConsumedAt() (r time.Time, exists bool) { + v := m.consumed_at if v == nil { return } return *v, true } -// OldUpdatedAt returns the old "updated_at" field's value of the PaymentProviderInstance entity. -// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// OldConsumedAt returns the old "consumed_at" field's value of the PendingAuthSession entity. +// If the PendingAuthSession object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *PaymentProviderInstanceMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { +func (m *PendingAuthSessionMutation) OldConsumedAt(ctx context.Context) (v *time.Time, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + return v, errors.New("OldConsumedAt is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + return v, errors.New("OldConsumedAt requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + return v, fmt.Errorf("querying old value for OldConsumedAt: %w", err) } - return oldValue.UpdatedAt, nil + return oldValue.ConsumedAt, nil } -// ResetUpdatedAt resets all changes to the "updated_at" field. -func (m *PaymentProviderInstanceMutation) ResetUpdatedAt() { - m.updated_at = nil +// ClearConsumedAt clears the value of the "consumed_at" field. +func (m *PendingAuthSessionMutation) ClearConsumedAt() { + m.consumed_at = nil + m.clearedFields[pendingauthsession.FieldConsumedAt] = struct{}{} } -// Where appends a list predicates to the PaymentProviderInstanceMutation builder. -func (m *PaymentProviderInstanceMutation) Where(ps ...predicate.PaymentProviderInstance) { +// ConsumedAtCleared returns if the "consumed_at" field was cleared in this mutation. +func (m *PendingAuthSessionMutation) ConsumedAtCleared() bool { + _, ok := m.clearedFields[pendingauthsession.FieldConsumedAt] + return ok +} + +// ResetConsumedAt resets all changes to the "consumed_at" field. +func (m *PendingAuthSessionMutation) ResetConsumedAt() { + m.consumed_at = nil + delete(m.clearedFields, pendingauthsession.FieldConsumedAt) +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (m *PendingAuthSessionMutation) ClearTargetUser() { + m.clearedtarget_user = true + m.clearedFields[pendingauthsession.FieldTargetUserID] = struct{}{} +} + +// TargetUserCleared reports if the "target_user" edge to the User entity was cleared. +func (m *PendingAuthSessionMutation) TargetUserCleared() bool { + return m.TargetUserIDCleared() || m.clearedtarget_user +} + +// TargetUserIDs returns the "target_user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// TargetUserID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) TargetUserIDs() (ids []int64) { + if id := m.target_user; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetTargetUser resets all changes to the "target_user" edge. +func (m *PendingAuthSessionMutation) ResetTargetUser() { + m.target_user = nil + m.clearedtarget_user = false +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by id. +func (m *PendingAuthSessionMutation) SetAdoptionDecisionID(id int64) { + m.adoption_decision = &id +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (m *PendingAuthSessionMutation) ClearAdoptionDecision() { + m.clearedadoption_decision = true +} + +// AdoptionDecisionCleared reports if the "adoption_decision" edge to the IdentityAdoptionDecision entity was cleared. +func (m *PendingAuthSessionMutation) AdoptionDecisionCleared() bool { + return m.clearedadoption_decision +} + +// AdoptionDecisionID returns the "adoption_decision" edge ID in the mutation. +func (m *PendingAuthSessionMutation) AdoptionDecisionID() (id int64, exists bool) { + if m.adoption_decision != nil { + return *m.adoption_decision, true + } + return +} + +// AdoptionDecisionIDs returns the "adoption_decision" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// AdoptionDecisionID instead. It exists only for internal usage by the builders. +func (m *PendingAuthSessionMutation) AdoptionDecisionIDs() (ids []int64) { + if id := m.adoption_decision; id != nil { + ids = append(ids, *id) + } + return +} + +// ResetAdoptionDecision resets all changes to the "adoption_decision" edge. +func (m *PendingAuthSessionMutation) ResetAdoptionDecision() { + m.adoption_decision = nil + m.clearedadoption_decision = false +} + +// Where appends a list predicates to the PendingAuthSessionMutation builder. +func (m *PendingAuthSessionMutation) Where(ps ...predicate.PendingAuthSession) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the PaymentProviderInstanceMutation builder. Using this method, +// WhereP appends storage-level predicates to the PendingAuthSessionMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.PaymentProviderInstance, len(ps)) +func (m *PendingAuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PendingAuthSession, len(ps)) for i := range ps { p[i] = ps[i] } @@ -16230,60 +20267,87 @@ func (m *PaymentProviderInstanceMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *PaymentProviderInstanceMutation) Op() Op { +func (m *PendingAuthSessionMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *PaymentProviderInstanceMutation) SetOp(op Op) { +func (m *PendingAuthSessionMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (PaymentProviderInstance). -func (m *PaymentProviderInstanceMutation) Type() string { +// Type returns the node type of this mutation (PendingAuthSession). +func (m *PendingAuthSessionMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 12) +func (m *PendingAuthSessionMutation) Fields() []string { + fields := make([]string, 0, 21) + if m.created_at != nil { + fields = append(fields, pendingauthsession.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, pendingauthsession.FieldUpdatedAt) + } + if m.session_token != nil { + fields = append(fields, pendingauthsession.FieldSessionToken) + } + if m.intent != nil { + fields = append(fields, pendingauthsession.FieldIntent) + } + if m.provider_type != nil { + fields = append(fields, pendingauthsession.FieldProviderType) + } if m.provider_key != nil { - fields = append(fields, paymentproviderinstance.FieldProviderKey) + fields = append(fields, pendingauthsession.FieldProviderKey) } - if m.name != nil { - fields = append(fields, paymentproviderinstance.FieldName) + if m.provider_subject != nil { + fields = append(fields, pendingauthsession.FieldProviderSubject) } - if m._config != nil { - fields = append(fields, paymentproviderinstance.FieldConfig) + if m.target_user != nil { + fields = append(fields, pendingauthsession.FieldTargetUserID) } - if m.supported_types != nil { - fields = append(fields, paymentproviderinstance.FieldSupportedTypes) + if m.redirect_to != nil { + fields = append(fields, pendingauthsession.FieldRedirectTo) } - if m.enabled != nil { - fields = append(fields, paymentproviderinstance.FieldEnabled) + if m.resolved_email != nil { + fields = append(fields, pendingauthsession.FieldResolvedEmail) } - if m.payment_mode != nil { - fields = append(fields, paymentproviderinstance.FieldPaymentMode) + if m.registration_password_hash != nil { + fields = append(fields, pendingauthsession.FieldRegistrationPasswordHash) } - if m.sort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) + if m.upstream_identity_claims != nil { + fields = append(fields, pendingauthsession.FieldUpstreamIdentityClaims) } - if m.limits != nil { - fields = append(fields, paymentproviderinstance.FieldLimits) + if m.local_flow_state != nil { + fields = append(fields, pendingauthsession.FieldLocalFlowState) } - if m.refund_enabled != nil { - fields = append(fields, paymentproviderinstance.FieldRefundEnabled) + if m.browser_session_key != nil { + fields = append(fields, pendingauthsession.FieldBrowserSessionKey) } - if m.allow_user_refund != nil { - fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + if m.completion_code_hash != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeHash) } - if m.created_at != nil { - fields = append(fields, paymentproviderinstance.FieldCreatedAt) + if m.completion_code_expires_at != nil { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) } - if m.updated_at != nil { - fields = append(fields, paymentproviderinstance.FieldUpdatedAt) + if m.email_verified_at != nil { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.password_verified_at != nil { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.totp_verified_at != nil { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.expires_at != nil { + fields = append(fields, pendingauthsession.FieldExpiresAt) + } + if m.consumed_at != nil { + fields = append(fields, pendingauthsession.FieldConsumedAt) } return fields } @@ -16291,32 +20355,50 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) Field(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.ProviderKey() - case paymentproviderinstance.FieldName: - return m.Name() - case paymentproviderinstance.FieldConfig: - return m.Config() - case paymentproviderinstance.FieldSupportedTypes: - return m.SupportedTypes() - case paymentproviderinstance.FieldEnabled: - return m.Enabled() - case paymentproviderinstance.FieldPaymentMode: - return m.PaymentMode() - case paymentproviderinstance.FieldSortOrder: - return m.SortOrder() - case paymentproviderinstance.FieldLimits: - return m.Limits() - case paymentproviderinstance.FieldRefundEnabled: - return m.RefundEnabled() - case paymentproviderinstance.FieldAllowUserRefund: - return m.AllowUserRefund() - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.CreatedAt() - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.UpdatedAt() + case pendingauthsession.FieldSessionToken: + return m.SessionToken() + case pendingauthsession.FieldIntent: + return m.Intent() + case pendingauthsession.FieldProviderType: + return m.ProviderType() + case pendingauthsession.FieldProviderKey: + return m.ProviderKey() + case pendingauthsession.FieldProviderSubject: + return m.ProviderSubject() + case pendingauthsession.FieldTargetUserID: + return m.TargetUserID() + case pendingauthsession.FieldRedirectTo: + return m.RedirectTo() + case pendingauthsession.FieldResolvedEmail: + return m.ResolvedEmail() + case pendingauthsession.FieldRegistrationPasswordHash: + return m.RegistrationPasswordHash() + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.UpstreamIdentityClaims() + case pendingauthsession.FieldLocalFlowState: + return m.LocalFlowState() + case pendingauthsession.FieldBrowserSessionKey: + return m.BrowserSessionKey() + case pendingauthsession.FieldCompletionCodeHash: + return m.CompletionCodeHash() + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.CompletionCodeExpiresAt() + case pendingauthsession.FieldEmailVerifiedAt: + return m.EmailVerifiedAt() + case pendingauthsession.FieldPasswordVerifiedAt: + return m.PasswordVerifiedAt() + case pendingauthsession.FieldTotpVerifiedAt: + return m.TotpVerifiedAt() + case pendingauthsession.FieldExpiresAt: + return m.ExpiresAt() + case pendingauthsession.FieldConsumedAt: + return m.ConsumedAt() } return nil, false } @@ -16324,146 +20406,222 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *PendingAuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case paymentproviderinstance.FieldProviderKey: - return m.OldProviderKey(ctx) - case paymentproviderinstance.FieldName: - return m.OldName(ctx) - case paymentproviderinstance.FieldConfig: - return m.OldConfig(ctx) - case paymentproviderinstance.FieldSupportedTypes: - return m.OldSupportedTypes(ctx) - case paymentproviderinstance.FieldEnabled: - return m.OldEnabled(ctx) - case paymentproviderinstance.FieldPaymentMode: - return m.OldPaymentMode(ctx) - case paymentproviderinstance.FieldSortOrder: - return m.OldSortOrder(ctx) - case paymentproviderinstance.FieldLimits: - return m.OldLimits(ctx) - case paymentproviderinstance.FieldRefundEnabled: - return m.OldRefundEnabled(ctx) - case paymentproviderinstance.FieldAllowUserRefund: - return m.OldAllowUserRefund(ctx) - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldCreatedAt: return m.OldCreatedAt(ctx) - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldUpdatedAt: return m.OldUpdatedAt(ctx) + case pendingauthsession.FieldSessionToken: + return m.OldSessionToken(ctx) + case pendingauthsession.FieldIntent: + return m.OldIntent(ctx) + case pendingauthsession.FieldProviderType: + return m.OldProviderType(ctx) + case pendingauthsession.FieldProviderKey: + return m.OldProviderKey(ctx) + case pendingauthsession.FieldProviderSubject: + return m.OldProviderSubject(ctx) + case pendingauthsession.FieldTargetUserID: + return m.OldTargetUserID(ctx) + case pendingauthsession.FieldRedirectTo: + return m.OldRedirectTo(ctx) + case pendingauthsession.FieldResolvedEmail: + return m.OldResolvedEmail(ctx) + case pendingauthsession.FieldRegistrationPasswordHash: + return m.OldRegistrationPasswordHash(ctx) + case pendingauthsession.FieldUpstreamIdentityClaims: + return m.OldUpstreamIdentityClaims(ctx) + case pendingauthsession.FieldLocalFlowState: + return m.OldLocalFlowState(ctx) + case pendingauthsession.FieldBrowserSessionKey: + return m.OldBrowserSessionKey(ctx) + case pendingauthsession.FieldCompletionCodeHash: + return m.OldCompletionCodeHash(ctx) + case pendingauthsession.FieldCompletionCodeExpiresAt: + return m.OldCompletionCodeExpiresAt(ctx) + case pendingauthsession.FieldEmailVerifiedAt: + return m.OldEmailVerifiedAt(ctx) + case pendingauthsession.FieldPasswordVerifiedAt: + return m.OldPasswordVerifiedAt(ctx) + case pendingauthsession.FieldTotpVerifiedAt: + return m.OldTotpVerifiedAt(ctx) + case pendingauthsession.FieldExpiresAt: + return m.OldExpiresAt(ctx) + case pendingauthsession.FieldConsumedAt: + return m.OldConsumedAt(ctx) } - return nil, fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return nil, fmt.Errorf("unknown PendingAuthSession field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) SetField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case pendingauthsession.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case pendingauthsession.FieldSessionToken: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetProviderKey(v) + m.SetSessionToken(v) return nil - case paymentproviderinstance.FieldName: + case pendingauthsession.FieldIntent: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetName(v) + m.SetIntent(v) return nil - case paymentproviderinstance.FieldConfig: + case pendingauthsession.FieldProviderType: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetConfig(v) + m.SetProviderType(v) return nil - case paymentproviderinstance.FieldSupportedTypes: + case pendingauthsession.FieldProviderKey: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSupportedTypes(v) + m.SetProviderKey(v) return nil - case paymentproviderinstance.FieldEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldProviderSubject: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetEnabled(v) + m.SetProviderSubject(v) return nil - case paymentproviderinstance.FieldPaymentMode: + case pendingauthsession.FieldTargetUserID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTargetUserID(v) + return nil + case pendingauthsession.FieldRedirectTo: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetPaymentMode(v) + m.SetRedirectTo(v) return nil - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) + case pendingauthsession.FieldResolvedEmail: + v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetSortOrder(v) + m.SetResolvedEmail(v) return nil - case paymentproviderinstance.FieldLimits: + case pendingauthsession.FieldRegistrationPasswordHash: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetLimits(v) + m.SetRegistrationPasswordHash(v) return nil - case paymentproviderinstance.FieldRefundEnabled: - v, ok := value.(bool) + case pendingauthsession.FieldUpstreamIdentityClaims: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRefundEnabled(v) + m.SetUpstreamIdentityClaims(v) return nil - case paymentproviderinstance.FieldAllowUserRefund: - v, ok := value.(bool) + case pendingauthsession.FieldLocalFlowState: + v, ok := value.(map[string]interface{}) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetAllowUserRefund(v) + m.SetLocalFlowState(v) return nil - case paymentproviderinstance.FieldCreatedAt: + case pendingauthsession.FieldBrowserSessionKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBrowserSessionKey(v) + return nil + case pendingauthsession.FieldCompletionCodeHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCompletionCodeHash(v) + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetCreatedAt(v) + m.SetCompletionCodeExpiresAt(v) return nil - case paymentproviderinstance.FieldUpdatedAt: + case pendingauthsession.FieldEmailVerifiedAt: v, ok := value.(time.Time) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetUpdatedAt(v) + m.SetEmailVerifiedAt(v) + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPasswordVerifiedAt(v) + return nil + case pendingauthsession.FieldTotpVerifiedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpVerifiedAt(v) + return nil + case pendingauthsession.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + case pendingauthsession.FieldConsumedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsumedAt(v) return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *PaymentProviderInstanceMutation) AddedFields() []string { +func (m *PendingAuthSessionMutation) AddedFields() []string { var fields []string - if m.addsort_order != nil { - fields = append(fields, paymentproviderinstance.FieldSortOrder) - } return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bool) { +func (m *PendingAuthSessionMutation) AddedField(name string) (ent.Value, bool) { switch name { - case paymentproviderinstance.FieldSortOrder: - return m.AddedSortOrder() } return nil, false } @@ -16471,128 +20629,231 @@ func (m *PaymentProviderInstanceMutation) AddedField(name string) (ent.Value, bo // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *PaymentProviderInstanceMutation) AddField(name string, value ent.Value) error { +func (m *PendingAuthSessionMutation) AddField(name string, value ent.Value) error { switch name { - case paymentproviderinstance.FieldSortOrder: - v, ok := value.(int) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddSortOrder(v) - return nil } - return fmt.Errorf("unknown PaymentProviderInstance numeric field %s", name) + return fmt.Errorf("unknown PendingAuthSession numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *PaymentProviderInstanceMutation) ClearedFields() []string { - return nil +func (m *PendingAuthSessionMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(pendingauthsession.FieldTargetUserID) { + fields = append(fields, pendingauthsession.FieldTargetUserID) + } + if m.FieldCleared(pendingauthsession.FieldCompletionCodeExpiresAt) { + fields = append(fields, pendingauthsession.FieldCompletionCodeExpiresAt) + } + if m.FieldCleared(pendingauthsession.FieldEmailVerifiedAt) { + fields = append(fields, pendingauthsession.FieldEmailVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldPasswordVerifiedAt) { + fields = append(fields, pendingauthsession.FieldPasswordVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldTotpVerifiedAt) { + fields = append(fields, pendingauthsession.FieldTotpVerifiedAt) + } + if m.FieldCleared(pendingauthsession.FieldConsumedAt) { + fields = append(fields, pendingauthsession.FieldConsumedAt) + } + return fields } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *PaymentProviderInstanceMutation) FieldCleared(name string) bool { +func (m *PendingAuthSessionMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ClearField(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance nullable field %s", name) +func (m *PendingAuthSessionMutation) ClearField(name string) error { + switch name { + case pendingauthsession.FieldTargetUserID: + m.ClearTargetUserID() + return nil + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ClearCompletionCodeExpiresAt() + return nil + case pendingauthsession.FieldEmailVerifiedAt: + m.ClearEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ClearPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ClearTotpVerifiedAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ClearConsumedAt() + return nil + } + return fmt.Errorf("unknown PendingAuthSession nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ResetField(name string) error { +func (m *PendingAuthSessionMutation) ResetField(name string) error { switch name { - case paymentproviderinstance.FieldProviderKey: + case pendingauthsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case pendingauthsession.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case pendingauthsession.FieldSessionToken: + m.ResetSessionToken() + return nil + case pendingauthsession.FieldIntent: + m.ResetIntent() + return nil + case pendingauthsession.FieldProviderType: + m.ResetProviderType() + return nil + case pendingauthsession.FieldProviderKey: m.ResetProviderKey() return nil - case paymentproviderinstance.FieldName: - m.ResetName() + case pendingauthsession.FieldProviderSubject: + m.ResetProviderSubject() return nil - case paymentproviderinstance.FieldConfig: - m.ResetConfig() + case pendingauthsession.FieldTargetUserID: + m.ResetTargetUserID() return nil - case paymentproviderinstance.FieldSupportedTypes: - m.ResetSupportedTypes() + case pendingauthsession.FieldRedirectTo: + m.ResetRedirectTo() return nil - case paymentproviderinstance.FieldEnabled: - m.ResetEnabled() + case pendingauthsession.FieldResolvedEmail: + m.ResetResolvedEmail() return nil - case paymentproviderinstance.FieldPaymentMode: - m.ResetPaymentMode() + case pendingauthsession.FieldRegistrationPasswordHash: + m.ResetRegistrationPasswordHash() return nil - case paymentproviderinstance.FieldSortOrder: - m.ResetSortOrder() + case pendingauthsession.FieldUpstreamIdentityClaims: + m.ResetUpstreamIdentityClaims() return nil - case paymentproviderinstance.FieldLimits: - m.ResetLimits() + case pendingauthsession.FieldLocalFlowState: + m.ResetLocalFlowState() return nil - case paymentproviderinstance.FieldRefundEnabled: - m.ResetRefundEnabled() + case pendingauthsession.FieldBrowserSessionKey: + m.ResetBrowserSessionKey() return nil - case paymentproviderinstance.FieldAllowUserRefund: - m.ResetAllowUserRefund() + case pendingauthsession.FieldCompletionCodeHash: + m.ResetCompletionCodeHash() return nil - case paymentproviderinstance.FieldCreatedAt: - m.ResetCreatedAt() + case pendingauthsession.FieldCompletionCodeExpiresAt: + m.ResetCompletionCodeExpiresAt() return nil - case paymentproviderinstance.FieldUpdatedAt: - m.ResetUpdatedAt() + case pendingauthsession.FieldEmailVerifiedAt: + m.ResetEmailVerifiedAt() + return nil + case pendingauthsession.FieldPasswordVerifiedAt: + m.ResetPasswordVerifiedAt() + return nil + case pendingauthsession.FieldTotpVerifiedAt: + m.ResetTotpVerifiedAt() + return nil + case pendingauthsession.FieldExpiresAt: + m.ResetExpiresAt() + return nil + case pendingauthsession.FieldConsumedAt: + m.ResetConsumedAt() return nil } - return fmt.Errorf("unknown PaymentProviderInstance field %s", name) + return fmt.Errorf("unknown PendingAuthSession field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *PaymentProviderInstanceMutation) AddedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.target_user != nil { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.adoption_decision != nil { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *PaymentProviderInstanceMutation) AddedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) AddedIDs(name string) []ent.Value { + switch name { + case pendingauthsession.EdgeTargetUser: + if id := m.target_user; id != nil { + return []ent.Value{*id} + } + case pendingauthsession.EdgeAdoptionDecision: + if id := m.adoption_decision; id != nil { + return []ent.Value{*id} + } + } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *PaymentProviderInstanceMutation) RemovedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *PaymentProviderInstanceMutation) RemovedIDs(name string) []ent.Value { +func (m *PendingAuthSessionMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *PaymentProviderInstanceMutation) ClearedEdges() []string { - edges := make([]string, 0, 0) +func (m *PendingAuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedtarget_user { + edges = append(edges, pendingauthsession.EdgeTargetUser) + } + if m.clearedadoption_decision { + edges = append(edges, pendingauthsession.EdgeAdoptionDecision) + } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *PaymentProviderInstanceMutation) EdgeCleared(name string) bool { +func (m *PendingAuthSessionMutation) EdgeCleared(name string) bool { + switch name { + case pendingauthsession.EdgeTargetUser: + return m.clearedtarget_user + case pendingauthsession.EdgeAdoptionDecision: + return m.clearedadoption_decision + } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ClearEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance unique edge %s", name) +func (m *PendingAuthSessionMutation) ClearEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ClearTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ClearAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *PaymentProviderInstanceMutation) ResetEdge(name string) error { - return fmt.Errorf("unknown PaymentProviderInstance edge %s", name) +func (m *PendingAuthSessionMutation) ResetEdge(name string) error { + switch name { + case pendingauthsession.EdgeTargetUser: + m.ResetTargetUser() + return nil + case pendingauthsession.EdgeAdoptionDecision: + m.ResetAdoptionDecision() + return nil + } + return fmt.Errorf("unknown PendingAuthSession edge %s", name) } // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. @@ -28264,6 +32525,9 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + signup_source *string + last_login_at *time.Time + last_active_at *time.Time balance_notify_enabled *bool balance_notify_threshold_type *string balance_notify_threshold *float64 @@ -28302,6 +32566,12 @@ type UserMutation struct { payment_orders map[int64]struct{} removedpayment_orders map[int64]struct{} clearedpayment_orders bool + auth_identities map[int64]struct{} + removedauth_identities map[int64]struct{} + clearedauth_identities bool + pending_auth_sessions map[int64]struct{} + removedpending_auth_sessions map[int64]struct{} + clearedpending_auth_sessions bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -28988,6 +33258,140 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSignupSource sets the "signup_source" field. +func (m *UserMutation) SetSignupSource(s string) { + m.signup_source = &s +} + +// SignupSource returns the value of the "signup_source" field in the mutation. +func (m *UserMutation) SignupSource() (r string, exists bool) { + v := m.signup_source + if v == nil { + return + } + return *v, true +} + +// OldSignupSource returns the old "signup_source" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSignupSource(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSignupSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSignupSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSignupSource: %w", err) + } + return oldValue.SignupSource, nil +} + +// ResetSignupSource resets all changes to the "signup_source" field. +func (m *UserMutation) ResetSignupSource() { + m.signup_source = nil +} + +// SetLastLoginAt sets the "last_login_at" field. +func (m *UserMutation) SetLastLoginAt(t time.Time) { + m.last_login_at = &t +} + +// LastLoginAt returns the value of the "last_login_at" field in the mutation. +func (m *UserMutation) LastLoginAt() (r time.Time, exists bool) { + v := m.last_login_at + if v == nil { + return + } + return *v, true +} + +// OldLastLoginAt returns the old "last_login_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastLoginAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLoginAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLoginAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLoginAt: %w", err) + } + return oldValue.LastLoginAt, nil +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (m *UserMutation) ClearLastLoginAt() { + m.last_login_at = nil + m.clearedFields[user.FieldLastLoginAt] = struct{}{} +} + +// LastLoginAtCleared returns if the "last_login_at" field was cleared in this mutation. +func (m *UserMutation) LastLoginAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastLoginAt] + return ok +} + +// ResetLastLoginAt resets all changes to the "last_login_at" field. +func (m *UserMutation) ResetLastLoginAt() { + m.last_login_at = nil + delete(m.clearedFields, user.FieldLastLoginAt) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (m *UserMutation) SetLastActiveAt(t time.Time) { + m.last_active_at = &t +} + +// LastActiveAt returns the value of the "last_active_at" field in the mutation. +func (m *UserMutation) LastActiveAt() (r time.Time, exists bool) { + v := m.last_active_at + if v == nil { + return + } + return *v, true +} + +// OldLastActiveAt returns the old "last_active_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldLastActiveAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActiveAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActiveAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActiveAt: %w", err) + } + return oldValue.LastActiveAt, nil +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (m *UserMutation) ClearLastActiveAt() { + m.last_active_at = nil + m.clearedFields[user.FieldLastActiveAt] = struct{}{} +} + +// LastActiveAtCleared returns if the "last_active_at" field was cleared in this mutation. +func (m *UserMutation) LastActiveAtCleared() bool { + _, ok := m.clearedFields[user.FieldLastActiveAt] + return ok +} + +// ResetLastActiveAt resets all changes to the "last_active_at" field. +func (m *UserMutation) ResetLastActiveAt() { + m.last_active_at = nil + delete(m.clearedFields, user.FieldLastActiveAt) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { m.balance_notify_enabled = &b @@ -29762,6 +34166,114 @@ func (m *UserMutation) ResetPaymentOrders() { m.removedpayment_orders = nil } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by ids. +func (m *UserMutation) AddAuthIdentityIDs(ids ...int64) { + if m.auth_identities == nil { + m.auth_identities = make(map[int64]struct{}) + } + for i := range ids { + m.auth_identities[ids[i]] = struct{}{} + } +} + +// ClearAuthIdentities clears the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) ClearAuthIdentities() { + m.clearedauth_identities = true +} + +// AuthIdentitiesCleared reports if the "auth_identities" edge to the AuthIdentity entity was cleared. +func (m *UserMutation) AuthIdentitiesCleared() bool { + return m.clearedauth_identities +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to the AuthIdentity entity by IDs. +func (m *UserMutation) RemoveAuthIdentityIDs(ids ...int64) { + if m.removedauth_identities == nil { + m.removedauth_identities = make(map[int64]struct{}) + } + for i := range ids { + delete(m.auth_identities, ids[i]) + m.removedauth_identities[ids[i]] = struct{}{} + } +} + +// RemovedAuthIdentities returns the removed IDs of the "auth_identities" edge to the AuthIdentity entity. +func (m *UserMutation) RemovedAuthIdentitiesIDs() (ids []int64) { + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return +} + +// AuthIdentitiesIDs returns the "auth_identities" edge IDs in the mutation. +func (m *UserMutation) AuthIdentitiesIDs() (ids []int64) { + for id := range m.auth_identities { + ids = append(ids, id) + } + return +} + +// ResetAuthIdentities resets all changes to the "auth_identities" edge. +func (m *UserMutation) ResetAuthIdentities() { + m.auth_identities = nil + m.clearedauth_identities = false + m.removedauth_identities = nil +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by ids. +func (m *UserMutation) AddPendingAuthSessionIDs(ids ...int64) { + if m.pending_auth_sessions == nil { + m.pending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + m.pending_auth_sessions[ids[i]] = struct{}{} + } +} + +// ClearPendingAuthSessions clears the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) ClearPendingAuthSessions() { + m.clearedpending_auth_sessions = true +} + +// PendingAuthSessionsCleared reports if the "pending_auth_sessions" edge to the PendingAuthSession entity was cleared. +func (m *UserMutation) PendingAuthSessionsCleared() bool { + return m.clearedpending_auth_sessions +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (m *UserMutation) RemovePendingAuthSessionIDs(ids ...int64) { + if m.removedpending_auth_sessions == nil { + m.removedpending_auth_sessions = make(map[int64]struct{}) + } + for i := range ids { + delete(m.pending_auth_sessions, ids[i]) + m.removedpending_auth_sessions[ids[i]] = struct{}{} + } +} + +// RemovedPendingAuthSessions returns the removed IDs of the "pending_auth_sessions" edge to the PendingAuthSession entity. +func (m *UserMutation) RemovedPendingAuthSessionsIDs() (ids []int64) { + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return +} + +// PendingAuthSessionsIDs returns the "pending_auth_sessions" edge IDs in the mutation. +func (m *UserMutation) PendingAuthSessionsIDs() (ids []int64) { + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return +} + +// ResetPendingAuthSessions resets all changes to the "pending_auth_sessions" edge. +func (m *UserMutation) ResetPendingAuthSessions() { + m.pending_auth_sessions = nil + m.clearedpending_auth_sessions = false + m.removedpending_auth_sessions = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -29796,7 +34308,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 19) + fields := make([]string, 0, 22) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29839,6 +34351,15 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.signup_source != nil { + fields = append(fields, user.FieldSignupSource) + } + if m.last_login_at != nil { + fields = append(fields, user.FieldLastLoginAt) + } + if m.last_active_at != nil { + fields = append(fields, user.FieldLastActiveAt) + } if m.balance_notify_enabled != nil { fields = append(fields, user.FieldBalanceNotifyEnabled) } @@ -29890,6 +34411,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSignupSource: + return m.SignupSource() + case user.FieldLastLoginAt: + return m.LastLoginAt() + case user.FieldLastActiveAt: + return m.LastActiveAt() case user.FieldBalanceNotifyEnabled: return m.BalanceNotifyEnabled() case user.FieldBalanceNotifyThresholdType: @@ -29937,6 +34464,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSignupSource: + return m.OldSignupSource(ctx) + case user.FieldLastLoginAt: + return m.OldLastLoginAt(ctx) + case user.FieldLastActiveAt: + return m.OldLastActiveAt(ctx) case user.FieldBalanceNotifyEnabled: return m.OldBalanceNotifyEnabled(ctx) case user.FieldBalanceNotifyThresholdType: @@ -30054,6 +34587,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSignupSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSignupSource(v) + return nil + case user.FieldLastLoginAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLoginAt(v) + return nil + case user.FieldLastActiveAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActiveAt(v) + return nil case user.FieldBalanceNotifyEnabled: v, ok := value.(bool) if !ok { @@ -30179,6 +34733,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldLastLoginAt) { + fields = append(fields, user.FieldLastLoginAt) + } + if m.FieldCleared(user.FieldLastActiveAt) { + fields = append(fields, user.FieldLastActiveAt) + } if m.FieldCleared(user.FieldBalanceNotifyThreshold) { fields = append(fields, user.FieldBalanceNotifyThreshold) } @@ -30205,6 +34765,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldLastLoginAt: + m.ClearLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ClearLastActiveAt() + return nil case user.FieldBalanceNotifyThreshold: m.ClearBalanceNotifyThreshold() return nil @@ -30258,6 +34824,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSignupSource: + m.ResetSignupSource() + return nil + case user.FieldLastLoginAt: + m.ResetLastLoginAt() + return nil + case user.FieldLastActiveAt: + m.ResetLastActiveAt() + return nil case user.FieldBalanceNotifyEnabled: m.ResetBalanceNotifyEnabled() return nil @@ -30279,7 +34854,7 @@ func (m *UserMutation) ResetField(name string) error { // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.api_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30310,6 +34885,12 @@ func (m *UserMutation) AddedEdges() []string { if m.payment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.auth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.pending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30377,13 +34958,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.auth_identities)) + for id := range m.auth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.pending_auth_sessions)) + for id := range m.pending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.removedapi_keys != nil { edges = append(edges, user.EdgeAPIKeys) } @@ -30414,6 +35007,12 @@ func (m *UserMutation) RemovedEdges() []string { if m.removedpayment_orders != nil { edges = append(edges, user.EdgePaymentOrders) } + if m.removedauth_identities != nil { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.removedpending_auth_sessions != nil { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30481,13 +35080,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeAuthIdentities: + ids := make([]ent.Value, 0, len(m.removedauth_identities)) + for id := range m.removedauth_identities { + ids = append(ids, id) + } + return ids + case user.EdgePendingAuthSessions: + ids := make([]ent.Value, 0, len(m.removedpending_auth_sessions)) + for id := range m.removedpending_auth_sessions { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 10) + edges := make([]string, 0, 12) if m.clearedapi_keys { edges = append(edges, user.EdgeAPIKeys) } @@ -30518,6 +35129,12 @@ func (m *UserMutation) ClearedEdges() []string { if m.clearedpayment_orders { edges = append(edges, user.EdgePaymentOrders) } + if m.clearedauth_identities { + edges = append(edges, user.EdgeAuthIdentities) + } + if m.clearedpending_auth_sessions { + edges = append(edges, user.EdgePendingAuthSessions) + } return edges } @@ -30545,6 +35162,10 @@ func (m *UserMutation) EdgeCleared(name string) bool { return m.clearedpromo_code_usages case user.EdgePaymentOrders: return m.clearedpayment_orders + case user.EdgeAuthIdentities: + return m.clearedauth_identities + case user.EdgePendingAuthSessions: + return m.clearedpending_auth_sessions } return false } @@ -30591,6 +35212,12 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePaymentOrders: m.ResetPaymentOrders() return nil + case user.EdgeAuthIdentities: + m.ResetAuthIdentities() + return nil + case user.EdgePendingAuthSessions: + m.ResetPendingAuthSessions() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/backend/ent/pendingauthsession.go b/backend/ent/pendingauthsession.go new file mode 100644 index 00000000..e77c065f --- /dev/null +++ b/backend/ent/pendingauthsession.go @@ -0,0 +1,399 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSession is the model entity for the PendingAuthSession schema. +type PendingAuthSession struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // SessionToken holds the value of the "session_token" field. + SessionToken string `json:"session_token,omitempty"` + // Intent holds the value of the "intent" field. + Intent string `json:"intent,omitempty"` + // ProviderType holds the value of the "provider_type" field. + ProviderType string `json:"provider_type,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey string `json:"provider_key,omitempty"` + // ProviderSubject holds the value of the "provider_subject" field. + ProviderSubject string `json:"provider_subject,omitempty"` + // TargetUserID holds the value of the "target_user_id" field. + TargetUserID *int64 `json:"target_user_id,omitempty"` + // RedirectTo holds the value of the "redirect_to" field. + RedirectTo string `json:"redirect_to,omitempty"` + // ResolvedEmail holds the value of the "resolved_email" field. + ResolvedEmail string `json:"resolved_email,omitempty"` + // RegistrationPasswordHash holds the value of the "registration_password_hash" field. + RegistrationPasswordHash string `json:"registration_password_hash,omitempty"` + // UpstreamIdentityClaims holds the value of the "upstream_identity_claims" field. + UpstreamIdentityClaims map[string]interface{} `json:"upstream_identity_claims,omitempty"` + // LocalFlowState holds the value of the "local_flow_state" field. + LocalFlowState map[string]interface{} `json:"local_flow_state,omitempty"` + // BrowserSessionKey holds the value of the "browser_session_key" field. + BrowserSessionKey string `json:"browser_session_key,omitempty"` + // CompletionCodeHash holds the value of the "completion_code_hash" field. + CompletionCodeHash string `json:"completion_code_hash,omitempty"` + // CompletionCodeExpiresAt holds the value of the "completion_code_expires_at" field. + CompletionCodeExpiresAt *time.Time `json:"completion_code_expires_at,omitempty"` + // EmailVerifiedAt holds the value of the "email_verified_at" field. + EmailVerifiedAt *time.Time `json:"email_verified_at,omitempty"` + // PasswordVerifiedAt holds the value of the "password_verified_at" field. + PasswordVerifiedAt *time.Time `json:"password_verified_at,omitempty"` + // TotpVerifiedAt holds the value of the "totp_verified_at" field. + TotpVerifiedAt *time.Time `json:"totp_verified_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + // ConsumedAt holds the value of the "consumed_at" field. + ConsumedAt *time.Time `json:"consumed_at,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the PendingAuthSessionQuery when eager-loading is set. + Edges PendingAuthSessionEdges `json:"edges"` + selectValues sql.SelectValues +} + +// PendingAuthSessionEdges holds the relations/edges for other nodes in the graph. +type PendingAuthSessionEdges struct { + // TargetUser holds the value of the target_user edge. + TargetUser *User `json:"target_user,omitempty"` + // AdoptionDecision holds the value of the adoption_decision edge. + AdoptionDecision *IdentityAdoptionDecision `json:"adoption_decision,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// TargetUserOrErr returns the TargetUser value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) TargetUserOrErr() (*User, error) { + if e.TargetUser != nil { + return e.TargetUser, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "target_user"} +} + +// AdoptionDecisionOrErr returns the AdoptionDecision value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e PendingAuthSessionEdges) AdoptionDecisionOrErr() (*IdentityAdoptionDecision, error) { + if e.AdoptionDecision != nil { + return e.AdoptionDecision, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: identityadoptiondecision.Label} + } + return nil, &NotLoadedError{edge: "adoption_decision"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*PendingAuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldUpstreamIdentityClaims, pendingauthsession.FieldLocalFlowState: + values[i] = new([]byte) + case pendingauthsession.FieldID, pendingauthsession.FieldTargetUserID: + values[i] = new(sql.NullInt64) + case pendingauthsession.FieldSessionToken, pendingauthsession.FieldIntent, pendingauthsession.FieldProviderType, pendingauthsession.FieldProviderKey, pendingauthsession.FieldProviderSubject, pendingauthsession.FieldRedirectTo, pendingauthsession.FieldResolvedEmail, pendingauthsession.FieldRegistrationPasswordHash, pendingauthsession.FieldBrowserSessionKey, pendingauthsession.FieldCompletionCodeHash: + values[i] = new(sql.NullString) + case pendingauthsession.FieldCreatedAt, pendingauthsession.FieldUpdatedAt, pendingauthsession.FieldCompletionCodeExpiresAt, pendingauthsession.FieldEmailVerifiedAt, pendingauthsession.FieldPasswordVerifiedAt, pendingauthsession.FieldTotpVerifiedAt, pendingauthsession.FieldExpiresAt, pendingauthsession.FieldConsumedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the PendingAuthSession fields. +func (_m *PendingAuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case pendingauthsession.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case pendingauthsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case pendingauthsession.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case pendingauthsession.FieldSessionToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field session_token", values[i]) + } else if value.Valid { + _m.SessionToken = value.String + } + case pendingauthsession.FieldIntent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field intent", values[i]) + } else if value.Valid { + _m.Intent = value.String + } + case pendingauthsession.FieldProviderType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_type", values[i]) + } else if value.Valid { + _m.ProviderType = value.String + } + case pendingauthsession.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = value.String + } + case pendingauthsession.FieldProviderSubject: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_subject", values[i]) + } else if value.Valid { + _m.ProviderSubject = value.String + } + case pendingauthsession.FieldTargetUserID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field target_user_id", values[i]) + } else if value.Valid { + _m.TargetUserID = new(int64) + *_m.TargetUserID = value.Int64 + } + case pendingauthsession.FieldRedirectTo: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field redirect_to", values[i]) + } else if value.Valid { + _m.RedirectTo = value.String + } + case pendingauthsession.FieldResolvedEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field resolved_email", values[i]) + } else if value.Valid { + _m.ResolvedEmail = value.String + } + case pendingauthsession.FieldRegistrationPasswordHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field registration_password_hash", values[i]) + } else if value.Valid { + _m.RegistrationPasswordHash = value.String + } + case pendingauthsession.FieldUpstreamIdentityClaims: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field upstream_identity_claims", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.UpstreamIdentityClaims); err != nil { + return fmt.Errorf("unmarshal field upstream_identity_claims: %w", err) + } + } + case pendingauthsession.FieldLocalFlowState: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field local_flow_state", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.LocalFlowState); err != nil { + return fmt.Errorf("unmarshal field local_flow_state: %w", err) + } + } + case pendingauthsession.FieldBrowserSessionKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field browser_session_key", values[i]) + } else if value.Valid { + _m.BrowserSessionKey = value.String + } + case pendingauthsession.FieldCompletionCodeHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_hash", values[i]) + } else if value.Valid { + _m.CompletionCodeHash = value.String + } + case pendingauthsession.FieldCompletionCodeExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field completion_code_expires_at", values[i]) + } else if value.Valid { + _m.CompletionCodeExpiresAt = new(time.Time) + *_m.CompletionCodeExpiresAt = value.Time + } + case pendingauthsession.FieldEmailVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field email_verified_at", values[i]) + } else if value.Valid { + _m.EmailVerifiedAt = new(time.Time) + *_m.EmailVerifiedAt = value.Time + } + case pendingauthsession.FieldPasswordVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field password_verified_at", values[i]) + } else if value.Valid { + _m.PasswordVerifiedAt = new(time.Time) + *_m.PasswordVerifiedAt = value.Time + } + case pendingauthsession.FieldTotpVerifiedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_verified_at", values[i]) + } else if value.Valid { + _m.TotpVerifiedAt = new(time.Time) + *_m.TotpVerifiedAt = value.Time + } + case pendingauthsession.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + case pendingauthsession.FieldConsumedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field consumed_at", values[i]) + } else if value.Valid { + _m.ConsumedAt = new(time.Time) + *_m.ConsumedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the PendingAuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *PendingAuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// QueryTargetUser queries the "target_user" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryTargetUser() *UserQuery { + return NewPendingAuthSessionClient(_m.config).QueryTargetUser(_m) +} + +// QueryAdoptionDecision queries the "adoption_decision" edge of the PendingAuthSession entity. +func (_m *PendingAuthSession) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + return NewPendingAuthSessionClient(_m.config).QueryAdoptionDecision(_m) +} + +// Update returns a builder for updating this PendingAuthSession. +// Note that you need to call PendingAuthSession.Unwrap() before calling this method if this PendingAuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *PendingAuthSession) Update() *PendingAuthSessionUpdateOne { + return NewPendingAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the PendingAuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *PendingAuthSession) Unwrap() *PendingAuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: PendingAuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *PendingAuthSession) String() string { + var builder strings.Builder + builder.WriteString("PendingAuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("session_token=") + builder.WriteString(_m.SessionToken) + builder.WriteString(", ") + builder.WriteString("intent=") + builder.WriteString(_m.Intent) + builder.WriteString(", ") + builder.WriteString("provider_type=") + builder.WriteString(_m.ProviderType) + builder.WriteString(", ") + builder.WriteString("provider_key=") + builder.WriteString(_m.ProviderKey) + builder.WriteString(", ") + builder.WriteString("provider_subject=") + builder.WriteString(_m.ProviderSubject) + builder.WriteString(", ") + if v := _m.TargetUserID; v != nil { + builder.WriteString("target_user_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("redirect_to=") + builder.WriteString(_m.RedirectTo) + builder.WriteString(", ") + builder.WriteString("resolved_email=") + builder.WriteString(_m.ResolvedEmail) + builder.WriteString(", ") + builder.WriteString("registration_password_hash=") + builder.WriteString(_m.RegistrationPasswordHash) + builder.WriteString(", ") + builder.WriteString("upstream_identity_claims=") + builder.WriteString(fmt.Sprintf("%v", _m.UpstreamIdentityClaims)) + builder.WriteString(", ") + builder.WriteString("local_flow_state=") + builder.WriteString(fmt.Sprintf("%v", _m.LocalFlowState)) + builder.WriteString(", ") + builder.WriteString("browser_session_key=") + builder.WriteString(_m.BrowserSessionKey) + builder.WriteString(", ") + builder.WriteString("completion_code_hash=") + builder.WriteString(_m.CompletionCodeHash) + builder.WriteString(", ") + if v := _m.CompletionCodeExpiresAt; v != nil { + builder.WriteString("completion_code_expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.EmailVerifiedAt; v != nil { + builder.WriteString("email_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.PasswordVerifiedAt; v != nil { + builder.WriteString("password_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TotpVerifiedAt; v != nil { + builder.WriteString("totp_verified_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteString(", ") + if v := _m.ConsumedAt; v != nil { + builder.WriteString("consumed_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// PendingAuthSessions is a parsable slice of PendingAuthSession. +type PendingAuthSessions []*PendingAuthSession diff --git a/backend/ent/pendingauthsession/pendingauthsession.go b/backend/ent/pendingauthsession/pendingauthsession.go new file mode 100644 index 00000000..8a3ac9bf --- /dev/null +++ b/backend/ent/pendingauthsession/pendingauthsession.go @@ -0,0 +1,279 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + +const ( + // Label holds the string label denoting the pendingauthsession type in the database. + Label = "pending_auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldSessionToken holds the string denoting the session_token field in the database. + FieldSessionToken = "session_token" + // FieldIntent holds the string denoting the intent field in the database. + FieldIntent = "intent" + // FieldProviderType holds the string denoting the provider_type field in the database. + FieldProviderType = "provider_type" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" + // FieldProviderSubject holds the string denoting the provider_subject field in the database. + FieldProviderSubject = "provider_subject" + // FieldTargetUserID holds the string denoting the target_user_id field in the database. + FieldTargetUserID = "target_user_id" + // FieldRedirectTo holds the string denoting the redirect_to field in the database. + FieldRedirectTo = "redirect_to" + // FieldResolvedEmail holds the string denoting the resolved_email field in the database. + FieldResolvedEmail = "resolved_email" + // FieldRegistrationPasswordHash holds the string denoting the registration_password_hash field in the database. + FieldRegistrationPasswordHash = "registration_password_hash" + // FieldUpstreamIdentityClaims holds the string denoting the upstream_identity_claims field in the database. + FieldUpstreamIdentityClaims = "upstream_identity_claims" + // FieldLocalFlowState holds the string denoting the local_flow_state field in the database. + FieldLocalFlowState = "local_flow_state" + // FieldBrowserSessionKey holds the string denoting the browser_session_key field in the database. + FieldBrowserSessionKey = "browser_session_key" + // FieldCompletionCodeHash holds the string denoting the completion_code_hash field in the database. + FieldCompletionCodeHash = "completion_code_hash" + // FieldCompletionCodeExpiresAt holds the string denoting the completion_code_expires_at field in the database. + FieldCompletionCodeExpiresAt = "completion_code_expires_at" + // FieldEmailVerifiedAt holds the string denoting the email_verified_at field in the database. + FieldEmailVerifiedAt = "email_verified_at" + // FieldPasswordVerifiedAt holds the string denoting the password_verified_at field in the database. + FieldPasswordVerifiedAt = "password_verified_at" + // FieldTotpVerifiedAt holds the string denoting the totp_verified_at field in the database. + FieldTotpVerifiedAt = "totp_verified_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // FieldConsumedAt holds the string denoting the consumed_at field in the database. + FieldConsumedAt = "consumed_at" + // EdgeTargetUser holds the string denoting the target_user edge name in mutations. + EdgeTargetUser = "target_user" + // EdgeAdoptionDecision holds the string denoting the adoption_decision edge name in mutations. + EdgeAdoptionDecision = "adoption_decision" + // Table holds the table name of the pendingauthsession in the database. + Table = "pending_auth_sessions" + // TargetUserTable is the table that holds the target_user relation/edge. + TargetUserTable = "pending_auth_sessions" + // TargetUserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + TargetUserInverseTable = "users" + // TargetUserColumn is the table column denoting the target_user relation/edge. + TargetUserColumn = "target_user_id" + // AdoptionDecisionTable is the table that holds the adoption_decision relation/edge. + AdoptionDecisionTable = "identity_adoption_decisions" + // AdoptionDecisionInverseTable is the table name for the IdentityAdoptionDecision entity. + // It exists in this package in order to avoid circular dependency with the "identityadoptiondecision" package. + AdoptionDecisionInverseTable = "identity_adoption_decisions" + // AdoptionDecisionColumn is the table column denoting the adoption_decision relation/edge. + AdoptionDecisionColumn = "pending_auth_session_id" +) + +// Columns holds all SQL columns for pendingauthsession fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldSessionToken, + FieldIntent, + FieldProviderType, + FieldProviderKey, + FieldProviderSubject, + FieldTargetUserID, + FieldRedirectTo, + FieldResolvedEmail, + FieldRegistrationPasswordHash, + FieldUpstreamIdentityClaims, + FieldLocalFlowState, + FieldBrowserSessionKey, + FieldCompletionCodeHash, + FieldCompletionCodeExpiresAt, + FieldEmailVerifiedAt, + FieldPasswordVerifiedAt, + FieldTotpVerifiedAt, + FieldExpiresAt, + FieldConsumedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + SessionTokenValidator func(string) error + // IntentValidator is a validator for the "intent" field. It is called by the builders before save. + IntentValidator func(string) error + // ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + ProviderTypeValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error + // ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + ProviderSubjectValidator func(string) error + // DefaultRedirectTo holds the default value on creation for the "redirect_to" field. + DefaultRedirectTo string + // DefaultResolvedEmail holds the default value on creation for the "resolved_email" field. + DefaultResolvedEmail string + // DefaultRegistrationPasswordHash holds the default value on creation for the "registration_password_hash" field. + DefaultRegistrationPasswordHash string + // DefaultUpstreamIdentityClaims holds the default value on creation for the "upstream_identity_claims" field. + DefaultUpstreamIdentityClaims func() map[string]interface{} + // DefaultLocalFlowState holds the default value on creation for the "local_flow_state" field. + DefaultLocalFlowState func() map[string]interface{} + // DefaultBrowserSessionKey holds the default value on creation for the "browser_session_key" field. + DefaultBrowserSessionKey string + // DefaultCompletionCodeHash holds the default value on creation for the "completion_code_hash" field. + DefaultCompletionCodeHash string +) + +// OrderOption defines the ordering options for the PendingAuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// BySessionToken orders the results by the session_token field. +func BySessionToken(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSessionToken, opts...).ToFunc() +} + +// ByIntent orders the results by the intent field. +func ByIntent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIntent, opts...).ToFunc() +} + +// ByProviderType orders the results by the provider_type field. +func ByProviderType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderType, opts...).ToFunc() +} + +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + +// ByProviderSubject orders the results by the provider_subject field. +func ByProviderSubject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderSubject, opts...).ToFunc() +} + +// ByTargetUserID orders the results by the target_user_id field. +func ByTargetUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTargetUserID, opts...).ToFunc() +} + +// ByRedirectTo orders the results by the redirect_to field. +func ByRedirectTo(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRedirectTo, opts...).ToFunc() +} + +// ByResolvedEmail orders the results by the resolved_email field. +func ByResolvedEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResolvedEmail, opts...).ToFunc() +} + +// ByRegistrationPasswordHash orders the results by the registration_password_hash field. +func ByRegistrationPasswordHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRegistrationPasswordHash, opts...).ToFunc() +} + +// ByBrowserSessionKey orders the results by the browser_session_key field. +func ByBrowserSessionKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBrowserSessionKey, opts...).ToFunc() +} + +// ByCompletionCodeHash orders the results by the completion_code_hash field. +func ByCompletionCodeHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeHash, opts...).ToFunc() +} + +// ByCompletionCodeExpiresAt orders the results by the completion_code_expires_at field. +func ByCompletionCodeExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCompletionCodeExpiresAt, opts...).ToFunc() +} + +// ByEmailVerifiedAt orders the results by the email_verified_at field. +func ByEmailVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEmailVerifiedAt, opts...).ToFunc() +} + +// ByPasswordVerifiedAt orders the results by the password_verified_at field. +func ByPasswordVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPasswordVerifiedAt, opts...).ToFunc() +} + +// ByTotpVerifiedAt orders the results by the totp_verified_at field. +func ByTotpVerifiedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpVerifiedAt, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + +// ByConsumedAt orders the results by the consumed_at field. +func ByConsumedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConsumedAt, opts...).ToFunc() +} + +// ByTargetUserField orders the results by target_user field. +func ByTargetUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTargetUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByAdoptionDecisionField orders the results by adoption_decision field. +func ByAdoptionDecisionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAdoptionDecisionStep(), sql.OrderByField(field, opts...)) + } +} +func newTargetUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TargetUserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) +} +func newAdoptionDecisionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AdoptionDecisionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) +} diff --git a/backend/ent/pendingauthsession/where.go b/backend/ent/pendingauthsession/where.go new file mode 100644 index 00000000..cb316f47 --- /dev/null +++ b/backend/ent/pendingauthsession/where.go @@ -0,0 +1,1262 @@ +// Code generated by ent, DO NOT EDIT. + +package pendingauthsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// SessionToken applies equality check predicate on the "session_token" field. It's identical to SessionTokenEQ. +func SessionToken(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// Intent applies equality check predicate on the "intent" field. It's identical to IntentEQ. +func Intent(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// ProviderType applies equality check predicate on the "provider_type" field. It's identical to ProviderTypeEQ. +func ProviderType(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderSubject applies equality check predicate on the "provider_subject" field. It's identical to ProviderSubjectEQ. +func ProviderSubject(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// TargetUserID applies equality check predicate on the "target_user_id" field. It's identical to TargetUserIDEQ. +func TargetUserID(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// RedirectTo applies equality check predicate on the "redirect_to" field. It's identical to RedirectToEQ. +func RedirectTo(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// ResolvedEmail applies equality check predicate on the "resolved_email" field. It's identical to ResolvedEmailEQ. +func ResolvedEmail(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHash applies equality check predicate on the "registration_password_hash" field. It's identical to RegistrationPasswordHashEQ. +func RegistrationPasswordHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKey applies equality check predicate on the "browser_session_key" field. It's identical to BrowserSessionKeyEQ. +func BrowserSessionKey(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHash applies equality check predicate on the "completion_code_hash" field. It's identical to CompletionCodeHashEQ. +func CompletionCodeHash(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAt applies equality check predicate on the "completion_code_expires_at" field. It's identical to CompletionCodeExpiresAtEQ. +func CompletionCodeExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// EmailVerifiedAt applies equality check predicate on the "email_verified_at" field. It's identical to EmailVerifiedAtEQ. +func EmailVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// PasswordVerifiedAt applies equality check predicate on the "password_verified_at" field. It's identical to PasswordVerifiedAtEQ. +func PasswordVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// TotpVerifiedAt applies equality check predicate on the "totp_verified_at" field. It's identical to TotpVerifiedAtEQ. +func TotpVerifiedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ConsumedAt applies equality check predicate on the "consumed_at" field. It's identical to ConsumedAtEQ. +func ConsumedAt(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// SessionTokenEQ applies the EQ predicate on the "session_token" field. +func SessionTokenEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldSessionToken, v)) +} + +// SessionTokenNEQ applies the NEQ predicate on the "session_token" field. +func SessionTokenNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldSessionToken, v)) +} + +// SessionTokenIn applies the In predicate on the "session_token" field. +func SessionTokenIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldSessionToken, vs...)) +} + +// SessionTokenNotIn applies the NotIn predicate on the "session_token" field. +func SessionTokenNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldSessionToken, vs...)) +} + +// SessionTokenGT applies the GT predicate on the "session_token" field. +func SessionTokenGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldSessionToken, v)) +} + +// SessionTokenGTE applies the GTE predicate on the "session_token" field. +func SessionTokenGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldSessionToken, v)) +} + +// SessionTokenLT applies the LT predicate on the "session_token" field. +func SessionTokenLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldSessionToken, v)) +} + +// SessionTokenLTE applies the LTE predicate on the "session_token" field. +func SessionTokenLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldSessionToken, v)) +} + +// SessionTokenContains applies the Contains predicate on the "session_token" field. +func SessionTokenContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldSessionToken, v)) +} + +// SessionTokenHasPrefix applies the HasPrefix predicate on the "session_token" field. +func SessionTokenHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldSessionToken, v)) +} + +// SessionTokenHasSuffix applies the HasSuffix predicate on the "session_token" field. +func SessionTokenHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldSessionToken, v)) +} + +// SessionTokenEqualFold applies the EqualFold predicate on the "session_token" field. +func SessionTokenEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldSessionToken, v)) +} + +// SessionTokenContainsFold applies the ContainsFold predicate on the "session_token" field. +func SessionTokenContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldSessionToken, v)) +} + +// IntentEQ applies the EQ predicate on the "intent" field. +func IntentEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldIntent, v)) +} + +// IntentNEQ applies the NEQ predicate on the "intent" field. +func IntentNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldIntent, v)) +} + +// IntentIn applies the In predicate on the "intent" field. +func IntentIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldIntent, vs...)) +} + +// IntentNotIn applies the NotIn predicate on the "intent" field. +func IntentNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldIntent, vs...)) +} + +// IntentGT applies the GT predicate on the "intent" field. +func IntentGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldIntent, v)) +} + +// IntentGTE applies the GTE predicate on the "intent" field. +func IntentGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldIntent, v)) +} + +// IntentLT applies the LT predicate on the "intent" field. +func IntentLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldIntent, v)) +} + +// IntentLTE applies the LTE predicate on the "intent" field. +func IntentLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldIntent, v)) +} + +// IntentContains applies the Contains predicate on the "intent" field. +func IntentContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldIntent, v)) +} + +// IntentHasPrefix applies the HasPrefix predicate on the "intent" field. +func IntentHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldIntent, v)) +} + +// IntentHasSuffix applies the HasSuffix predicate on the "intent" field. +func IntentHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldIntent, v)) +} + +// IntentEqualFold applies the EqualFold predicate on the "intent" field. +func IntentEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldIntent, v)) +} + +// IntentContainsFold applies the ContainsFold predicate on the "intent" field. +func IntentContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldIntent, v)) +} + +// ProviderTypeEQ applies the EQ predicate on the "provider_type" field. +func ProviderTypeEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderType, v)) +} + +// ProviderTypeNEQ applies the NEQ predicate on the "provider_type" field. +func ProviderTypeNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderType, v)) +} + +// ProviderTypeIn applies the In predicate on the "provider_type" field. +func ProviderTypeIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderType, vs...)) +} + +// ProviderTypeNotIn applies the NotIn predicate on the "provider_type" field. +func ProviderTypeNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderType, vs...)) +} + +// ProviderTypeGT applies the GT predicate on the "provider_type" field. +func ProviderTypeGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderType, v)) +} + +// ProviderTypeGTE applies the GTE predicate on the "provider_type" field. +func ProviderTypeGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderType, v)) +} + +// ProviderTypeLT applies the LT predicate on the "provider_type" field. +func ProviderTypeLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderType, v)) +} + +// ProviderTypeLTE applies the LTE predicate on the "provider_type" field. +func ProviderTypeLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderType, v)) +} + +// ProviderTypeContains applies the Contains predicate on the "provider_type" field. +func ProviderTypeContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderType, v)) +} + +// ProviderTypeHasPrefix applies the HasPrefix predicate on the "provider_type" field. +func ProviderTypeHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderType, v)) +} + +// ProviderTypeHasSuffix applies the HasSuffix predicate on the "provider_type" field. +func ProviderTypeHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderType, v)) +} + +// ProviderTypeEqualFold applies the EqualFold predicate on the "provider_type" field. +func ProviderTypeEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderType, v)) +} + +// ProviderTypeContainsFold applies the ContainsFold predicate on the "provider_type" field. +func ProviderTypeContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderType, v)) +} + +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderKey, v)) +} + +// ProviderSubjectEQ applies the EQ predicate on the "provider_subject" field. +func ProviderSubjectEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectNEQ applies the NEQ predicate on the "provider_subject" field. +func ProviderSubjectNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldProviderSubject, v)) +} + +// ProviderSubjectIn applies the In predicate on the "provider_subject" field. +func ProviderSubjectIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectNotIn applies the NotIn predicate on the "provider_subject" field. +func ProviderSubjectNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldProviderSubject, vs...)) +} + +// ProviderSubjectGT applies the GT predicate on the "provider_subject" field. +func ProviderSubjectGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldProviderSubject, v)) +} + +// ProviderSubjectGTE applies the GTE predicate on the "provider_subject" field. +func ProviderSubjectGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldProviderSubject, v)) +} + +// ProviderSubjectLT applies the LT predicate on the "provider_subject" field. +func ProviderSubjectLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldProviderSubject, v)) +} + +// ProviderSubjectLTE applies the LTE predicate on the "provider_subject" field. +func ProviderSubjectLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldProviderSubject, v)) +} + +// ProviderSubjectContains applies the Contains predicate on the "provider_subject" field. +func ProviderSubjectContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldProviderSubject, v)) +} + +// ProviderSubjectHasPrefix applies the HasPrefix predicate on the "provider_subject" field. +func ProviderSubjectHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldProviderSubject, v)) +} + +// ProviderSubjectHasSuffix applies the HasSuffix predicate on the "provider_subject" field. +func ProviderSubjectHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldProviderSubject, v)) +} + +// ProviderSubjectEqualFold applies the EqualFold predicate on the "provider_subject" field. +func ProviderSubjectEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldProviderSubject, v)) +} + +// ProviderSubjectContainsFold applies the ContainsFold predicate on the "provider_subject" field. +func ProviderSubjectContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldProviderSubject, v)) +} + +// TargetUserIDEQ applies the EQ predicate on the "target_user_id" field. +func TargetUserIDEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTargetUserID, v)) +} + +// TargetUserIDNEQ applies the NEQ predicate on the "target_user_id" field. +func TargetUserIDNEQ(v int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTargetUserID, v)) +} + +// TargetUserIDIn applies the In predicate on the "target_user_id" field. +func TargetUserIDIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDNotIn applies the NotIn predicate on the "target_user_id" field. +func TargetUserIDNotIn(vs ...int64) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTargetUserID, vs...)) +} + +// TargetUserIDIsNil applies the IsNil predicate on the "target_user_id" field. +func TargetUserIDIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTargetUserID)) +} + +// TargetUserIDNotNil applies the NotNil predicate on the "target_user_id" field. +func TargetUserIDNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTargetUserID)) +} + +// RedirectToEQ applies the EQ predicate on the "redirect_to" field. +func RedirectToEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRedirectTo, v)) +} + +// RedirectToNEQ applies the NEQ predicate on the "redirect_to" field. +func RedirectToNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRedirectTo, v)) +} + +// RedirectToIn applies the In predicate on the "redirect_to" field. +func RedirectToIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRedirectTo, vs...)) +} + +// RedirectToNotIn applies the NotIn predicate on the "redirect_to" field. +func RedirectToNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRedirectTo, vs...)) +} + +// RedirectToGT applies the GT predicate on the "redirect_to" field. +func RedirectToGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRedirectTo, v)) +} + +// RedirectToGTE applies the GTE predicate on the "redirect_to" field. +func RedirectToGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRedirectTo, v)) +} + +// RedirectToLT applies the LT predicate on the "redirect_to" field. +func RedirectToLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRedirectTo, v)) +} + +// RedirectToLTE applies the LTE predicate on the "redirect_to" field. +func RedirectToLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRedirectTo, v)) +} + +// RedirectToContains applies the Contains predicate on the "redirect_to" field. +func RedirectToContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRedirectTo, v)) +} + +// RedirectToHasPrefix applies the HasPrefix predicate on the "redirect_to" field. +func RedirectToHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRedirectTo, v)) +} + +// RedirectToHasSuffix applies the HasSuffix predicate on the "redirect_to" field. +func RedirectToHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRedirectTo, v)) +} + +// RedirectToEqualFold applies the EqualFold predicate on the "redirect_to" field. +func RedirectToEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRedirectTo, v)) +} + +// RedirectToContainsFold applies the ContainsFold predicate on the "redirect_to" field. +func RedirectToContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRedirectTo, v)) +} + +// ResolvedEmailEQ applies the EQ predicate on the "resolved_email" field. +func ResolvedEmailEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailNEQ applies the NEQ predicate on the "resolved_email" field. +func ResolvedEmailNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldResolvedEmail, v)) +} + +// ResolvedEmailIn applies the In predicate on the "resolved_email" field. +func ResolvedEmailIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailNotIn applies the NotIn predicate on the "resolved_email" field. +func ResolvedEmailNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldResolvedEmail, vs...)) +} + +// ResolvedEmailGT applies the GT predicate on the "resolved_email" field. +func ResolvedEmailGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldResolvedEmail, v)) +} + +// ResolvedEmailGTE applies the GTE predicate on the "resolved_email" field. +func ResolvedEmailGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailLT applies the LT predicate on the "resolved_email" field. +func ResolvedEmailLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldResolvedEmail, v)) +} + +// ResolvedEmailLTE applies the LTE predicate on the "resolved_email" field. +func ResolvedEmailLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldResolvedEmail, v)) +} + +// ResolvedEmailContains applies the Contains predicate on the "resolved_email" field. +func ResolvedEmailContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasPrefix applies the HasPrefix predicate on the "resolved_email" field. +func ResolvedEmailHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldResolvedEmail, v)) +} + +// ResolvedEmailHasSuffix applies the HasSuffix predicate on the "resolved_email" field. +func ResolvedEmailHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldResolvedEmail, v)) +} + +// ResolvedEmailEqualFold applies the EqualFold predicate on the "resolved_email" field. +func ResolvedEmailEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldResolvedEmail, v)) +} + +// ResolvedEmailContainsFold applies the ContainsFold predicate on the "resolved_email" field. +func ResolvedEmailContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldResolvedEmail, v)) +} + +// RegistrationPasswordHashEQ applies the EQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashNEQ applies the NEQ predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashIn applies the In predicate on the "registration_password_hash" field. +func RegistrationPasswordHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashNotIn applies the NotIn predicate on the "registration_password_hash" field. +func RegistrationPasswordHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldRegistrationPasswordHash, vs...)) +} + +// RegistrationPasswordHashGT applies the GT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashGTE applies the GTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLT applies the LT predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashLTE applies the LTE predicate on the "registration_password_hash" field. +func RegistrationPasswordHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContains applies the Contains predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasPrefix applies the HasPrefix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashHasSuffix applies the HasSuffix predicate on the "registration_password_hash" field. +func RegistrationPasswordHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashEqualFold applies the EqualFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldRegistrationPasswordHash, v)) +} + +// RegistrationPasswordHashContainsFold applies the ContainsFold predicate on the "registration_password_hash" field. +func RegistrationPasswordHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldRegistrationPasswordHash, v)) +} + +// BrowserSessionKeyEQ applies the EQ predicate on the "browser_session_key" field. +func BrowserSessionKeyEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyNEQ applies the NEQ predicate on the "browser_session_key" field. +func BrowserSessionKeyNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyIn applies the In predicate on the "browser_session_key" field. +func BrowserSessionKeyIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyNotIn applies the NotIn predicate on the "browser_session_key" field. +func BrowserSessionKeyNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldBrowserSessionKey, vs...)) +} + +// BrowserSessionKeyGT applies the GT predicate on the "browser_session_key" field. +func BrowserSessionKeyGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyGTE applies the GTE predicate on the "browser_session_key" field. +func BrowserSessionKeyGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLT applies the LT predicate on the "browser_session_key" field. +func BrowserSessionKeyLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyLTE applies the LTE predicate on the "browser_session_key" field. +func BrowserSessionKeyLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContains applies the Contains predicate on the "browser_session_key" field. +func BrowserSessionKeyContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasPrefix applies the HasPrefix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyHasSuffix applies the HasSuffix predicate on the "browser_session_key" field. +func BrowserSessionKeyHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyEqualFold applies the EqualFold predicate on the "browser_session_key" field. +func BrowserSessionKeyEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldBrowserSessionKey, v)) +} + +// BrowserSessionKeyContainsFold applies the ContainsFold predicate on the "browser_session_key" field. +func BrowserSessionKeyContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldBrowserSessionKey, v)) +} + +// CompletionCodeHashEQ applies the EQ predicate on the "completion_code_hash" field. +func CompletionCodeHashEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashNEQ applies the NEQ predicate on the "completion_code_hash" field. +func CompletionCodeHashNEQ(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashIn applies the In predicate on the "completion_code_hash" field. +func CompletionCodeHashIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashNotIn applies the NotIn predicate on the "completion_code_hash" field. +func CompletionCodeHashNotIn(vs ...string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeHash, vs...)) +} + +// CompletionCodeHashGT applies the GT predicate on the "completion_code_hash" field. +func CompletionCodeHashGT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashGTE applies the GTE predicate on the "completion_code_hash" field. +func CompletionCodeHashGTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLT applies the LT predicate on the "completion_code_hash" field. +func CompletionCodeHashLT(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashLTE applies the LTE predicate on the "completion_code_hash" field. +func CompletionCodeHashLTE(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContains applies the Contains predicate on the "completion_code_hash" field. +func CompletionCodeHashContains(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContains(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasPrefix applies the HasPrefix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasPrefix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasPrefix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashHasSuffix applies the HasSuffix predicate on the "completion_code_hash" field. +func CompletionCodeHashHasSuffix(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldHasSuffix(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashEqualFold applies the EqualFold predicate on the "completion_code_hash" field. +func CompletionCodeHashEqualFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEqualFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeHashContainsFold applies the ContainsFold predicate on the "completion_code_hash" field. +func CompletionCodeHashContainsFold(v string) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldContainsFold(FieldCompletionCodeHash, v)) +} + +// CompletionCodeExpiresAtEQ applies the EQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtNEQ applies the NEQ predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIn applies the In predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtNotIn applies the NotIn predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldCompletionCodeExpiresAt, vs...)) +} + +// CompletionCodeExpiresAtGT applies the GT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtGTE applies the GTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLT applies the LT predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtLTE applies the LTE predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldCompletionCodeExpiresAt, v)) +} + +// CompletionCodeExpiresAtIsNil applies the IsNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldCompletionCodeExpiresAt)) +} + +// CompletionCodeExpiresAtNotNil applies the NotNil predicate on the "completion_code_expires_at" field. +func CompletionCodeExpiresAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldCompletionCodeExpiresAt)) +} + +// EmailVerifiedAtEQ applies the EQ predicate on the "email_verified_at" field. +func EmailVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtNEQ applies the NEQ predicate on the "email_verified_at" field. +func EmailVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIn applies the In predicate on the "email_verified_at" field. +func EmailVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtNotIn applies the NotIn predicate on the "email_verified_at" field. +func EmailVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldEmailVerifiedAt, vs...)) +} + +// EmailVerifiedAtGT applies the GT predicate on the "email_verified_at" field. +func EmailVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtGTE applies the GTE predicate on the "email_verified_at" field. +func EmailVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLT applies the LT predicate on the "email_verified_at" field. +func EmailVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtLTE applies the LTE predicate on the "email_verified_at" field. +func EmailVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldEmailVerifiedAt, v)) +} + +// EmailVerifiedAtIsNil applies the IsNil predicate on the "email_verified_at" field. +func EmailVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldEmailVerifiedAt)) +} + +// EmailVerifiedAtNotNil applies the NotNil predicate on the "email_verified_at" field. +func EmailVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldEmailVerifiedAt)) +} + +// PasswordVerifiedAtEQ applies the EQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtNEQ applies the NEQ predicate on the "password_verified_at" field. +func PasswordVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIn applies the In predicate on the "password_verified_at" field. +func PasswordVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtNotIn applies the NotIn predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldPasswordVerifiedAt, vs...)) +} + +// PasswordVerifiedAtGT applies the GT predicate on the "password_verified_at" field. +func PasswordVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtGTE applies the GTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLT applies the LT predicate on the "password_verified_at" field. +func PasswordVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtLTE applies the LTE predicate on the "password_verified_at" field. +func PasswordVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldPasswordVerifiedAt, v)) +} + +// PasswordVerifiedAtIsNil applies the IsNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldPasswordVerifiedAt)) +} + +// PasswordVerifiedAtNotNil applies the NotNil predicate on the "password_verified_at" field. +func PasswordVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldPasswordVerifiedAt)) +} + +// TotpVerifiedAtEQ applies the EQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtNEQ applies the NEQ predicate on the "totp_verified_at" field. +func TotpVerifiedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIn applies the In predicate on the "totp_verified_at" field. +func TotpVerifiedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtNotIn applies the NotIn predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldTotpVerifiedAt, vs...)) +} + +// TotpVerifiedAtGT applies the GT predicate on the "totp_verified_at" field. +func TotpVerifiedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtGTE applies the GTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLT applies the LT predicate on the "totp_verified_at" field. +func TotpVerifiedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtLTE applies the LTE predicate on the "totp_verified_at" field. +func TotpVerifiedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldTotpVerifiedAt, v)) +} + +// TotpVerifiedAtIsNil applies the IsNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldTotpVerifiedAt)) +} + +// TotpVerifiedAtNotNil applies the NotNil predicate on the "totp_verified_at" field. +func TotpVerifiedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldTotpVerifiedAt)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ConsumedAtEQ applies the EQ predicate on the "consumed_at" field. +func ConsumedAtEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldEQ(FieldConsumedAt, v)) +} + +// ConsumedAtNEQ applies the NEQ predicate on the "consumed_at" field. +func ConsumedAtNEQ(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNEQ(FieldConsumedAt, v)) +} + +// ConsumedAtIn applies the In predicate on the "consumed_at" field. +func ConsumedAtIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtNotIn applies the NotIn predicate on the "consumed_at" field. +func ConsumedAtNotIn(vs ...time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotIn(FieldConsumedAt, vs...)) +} + +// ConsumedAtGT applies the GT predicate on the "consumed_at" field. +func ConsumedAtGT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGT(FieldConsumedAt, v)) +} + +// ConsumedAtGTE applies the GTE predicate on the "consumed_at" field. +func ConsumedAtGTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldGTE(FieldConsumedAt, v)) +} + +// ConsumedAtLT applies the LT predicate on the "consumed_at" field. +func ConsumedAtLT(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLT(FieldConsumedAt, v)) +} + +// ConsumedAtLTE applies the LTE predicate on the "consumed_at" field. +func ConsumedAtLTE(v time.Time) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldLTE(FieldConsumedAt, v)) +} + +// ConsumedAtIsNil applies the IsNil predicate on the "consumed_at" field. +func ConsumedAtIsNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldIsNull(FieldConsumedAt)) +} + +// ConsumedAtNotNil applies the NotNil predicate on the "consumed_at" field. +func ConsumedAtNotNil() predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.FieldNotNull(FieldConsumedAt)) +} + +// HasTargetUser applies the HasEdge predicate on the "target_user" edge. +func HasTargetUser() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, TargetUserTable, TargetUserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasTargetUserWith applies the HasEdge predicate on the "target_user" edge with a given conditions (other predicates). +func HasTargetUserWith(preds ...predicate.User) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newTargetUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasAdoptionDecision applies the HasEdge predicate on the "adoption_decision" edge. +func HasAdoptionDecision() predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AdoptionDecisionTable, AdoptionDecisionColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAdoptionDecisionWith applies the HasEdge predicate on the "adoption_decision" edge with a given conditions (other predicates). +func HasAdoptionDecisionWith(preds ...predicate.IdentityAdoptionDecision) predicate.PendingAuthSession { + return predicate.PendingAuthSession(func(s *sql.Selector) { + step := newAdoptionDecisionStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.PendingAuthSession) predicate.PendingAuthSession { + return predicate.PendingAuthSession(sql.NotPredicates(p)) +} diff --git a/backend/ent/pendingauthsession_create.go b/backend/ent/pendingauthsession_create.go new file mode 100644 index 00000000..60276daa --- /dev/null +++ b/backend/ent/pendingauthsession_create.go @@ -0,0 +1,1815 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionCreate is the builder for creating a PendingAuthSession entity. +type PendingAuthSessionCreate struct { + config + mutation *PendingAuthSessionMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *PendingAuthSessionCreate) SetCreatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCreatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *PendingAuthSessionCreate) SetUpdatedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableUpdatedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetSessionToken sets the "session_token" field. +func (_c *PendingAuthSessionCreate) SetSessionToken(v string) *PendingAuthSessionCreate { + _c.mutation.SetSessionToken(v) + return _c +} + +// SetIntent sets the "intent" field. +func (_c *PendingAuthSessionCreate) SetIntent(v string) *PendingAuthSessionCreate { + _c.mutation.SetIntent(v) + return _c +} + +// SetProviderType sets the "provider_type" field. +func (_c *PendingAuthSessionCreate) SetProviderType(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderType(v) + return _c +} + +// SetProviderKey sets the "provider_key" field. +func (_c *PendingAuthSessionCreate) SetProviderKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetProviderSubject sets the "provider_subject" field. +func (_c *PendingAuthSessionCreate) SetProviderSubject(v string) *PendingAuthSessionCreate { + _c.mutation.SetProviderSubject(v) + return _c +} + +// SetTargetUserID sets the "target_user_id" field. +func (_c *PendingAuthSessionCreate) SetTargetUserID(v int64) *PendingAuthSessionCreate { + _c.mutation.SetTargetUserID(v) + return _c +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTargetUserID(v *int64) *PendingAuthSessionCreate { + if v != nil { + _c.SetTargetUserID(*v) + } + return _c +} + +// SetRedirectTo sets the "redirect_to" field. +func (_c *PendingAuthSessionCreate) SetRedirectTo(v string) *PendingAuthSessionCreate { + _c.mutation.SetRedirectTo(v) + return _c +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRedirectTo(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRedirectTo(*v) + } + return _c +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_c *PendingAuthSessionCreate) SetResolvedEmail(v string) *PendingAuthSessionCreate { + _c.mutation.SetResolvedEmail(v) + return _c +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableResolvedEmail(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetResolvedEmail(*v) + } + return _c +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_c *PendingAuthSessionCreate) SetRegistrationPasswordHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetRegistrationPasswordHash(v) + return _c +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetRegistrationPasswordHash(*v) + } + return _c +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_c *PendingAuthSessionCreate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetUpstreamIdentityClaims(v) + return _c +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_c *PendingAuthSessionCreate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionCreate { + _c.mutation.SetLocalFlowState(v) + return _c +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_c *PendingAuthSessionCreate) SetBrowserSessionKey(v string) *PendingAuthSessionCreate { + _c.mutation.SetBrowserSessionKey(v) + return _c +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetBrowserSessionKey(*v) + } + return _c +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeHash(v string) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeHash(v) + return _c +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeHash(*v) + } + return _c +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_c *PendingAuthSessionCreate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetCompletionCodeExpiresAt(v) + return _c +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetCompletionCodeExpiresAt(*v) + } + return _c +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_c *PendingAuthSessionCreate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetEmailVerifiedAt(v) + return _c +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetEmailVerifiedAt(*v) + } + return _c +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_c *PendingAuthSessionCreate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetPasswordVerifiedAt(v) + return _c +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetPasswordVerifiedAt(*v) + } + return _c +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_c *PendingAuthSessionCreate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetTotpVerifiedAt(v) + return _c +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetTotpVerifiedAt(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *PendingAuthSessionCreate) SetExpiresAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetConsumedAt sets the "consumed_at" field. +func (_c *PendingAuthSessionCreate) SetConsumedAt(v time.Time) *PendingAuthSessionCreate { + _c.mutation.SetConsumedAt(v) + return _c +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionCreate { + if v != nil { + _c.SetConsumedAt(*v) + } + return _c +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_c *PendingAuthSessionCreate) SetTargetUser(v *User) *PendingAuthSessionCreate { + return _c.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_c *PendingAuthSessionCreate) SetAdoptionDecisionID(id int64) *PendingAuthSessionCreate { + _c.mutation.SetAdoptionDecisionID(id) + return _c +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_c *PendingAuthSessionCreate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionCreate { + if id != nil { + _c = _c.SetAdoptionDecisionID(*id) + } + return _c +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_c *PendingAuthSessionCreate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionCreate { + return _c.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_c *PendingAuthSessionCreate) Mutation() *PendingAuthSessionMutation { + return _c.mutation +} + +// Save creates the PendingAuthSession in the database. +func (_c *PendingAuthSessionCreate) Save(ctx context.Context) (*PendingAuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *PendingAuthSessionCreate) SaveX(ctx context.Context) *PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *PendingAuthSessionCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := pendingauthsession.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := pendingauthsession.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.RedirectTo(); !ok { + v := pendingauthsession.DefaultRedirectTo + _c.mutation.SetRedirectTo(v) + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + v := pendingauthsession.DefaultResolvedEmail + _c.mutation.SetResolvedEmail(v) + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + v := pendingauthsession.DefaultRegistrationPasswordHash + _c.mutation.SetRegistrationPasswordHash(v) + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + v := pendingauthsession.DefaultUpstreamIdentityClaims() + _c.mutation.SetUpstreamIdentityClaims(v) + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + v := pendingauthsession.DefaultLocalFlowState() + _c.mutation.SetLocalFlowState(v) + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + v := pendingauthsession.DefaultBrowserSessionKey + _c.mutation.SetBrowserSessionKey(v) + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + v := pendingauthsession.DefaultCompletionCodeHash + _c.mutation.SetCompletionCodeHash(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *PendingAuthSessionCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PendingAuthSession.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "PendingAuthSession.updated_at"`)} + } + if _, ok := _c.mutation.SessionToken(); !ok { + return &ValidationError{Name: "session_token", err: errors.New(`ent: missing required field "PendingAuthSession.session_token"`)} + } + if v, ok := _c.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if _, ok := _c.mutation.Intent(); !ok { + return &ValidationError{Name: "intent", err: errors.New(`ent: missing required field "PendingAuthSession.intent"`)} + } + if v, ok := _c.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderType(); !ok { + return &ValidationError{Name: "provider_type", err: errors.New(`ent: missing required field "PendingAuthSession.provider_type"`)} + } + if v, ok := _c.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderKey(); !ok { + return &ValidationError{Name: "provider_key", err: errors.New(`ent: missing required field "PendingAuthSession.provider_key"`)} + } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if _, ok := _c.mutation.ProviderSubject(); !ok { + return &ValidationError{Name: "provider_subject", err: errors.New(`ent: missing required field "PendingAuthSession.provider_subject"`)} + } + if v, ok := _c.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + if _, ok := _c.mutation.RedirectTo(); !ok { + return &ValidationError{Name: "redirect_to", err: errors.New(`ent: missing required field "PendingAuthSession.redirect_to"`)} + } + if _, ok := _c.mutation.ResolvedEmail(); !ok { + return &ValidationError{Name: "resolved_email", err: errors.New(`ent: missing required field "PendingAuthSession.resolved_email"`)} + } + if _, ok := _c.mutation.RegistrationPasswordHash(); !ok { + return &ValidationError{Name: "registration_password_hash", err: errors.New(`ent: missing required field "PendingAuthSession.registration_password_hash"`)} + } + if _, ok := _c.mutation.UpstreamIdentityClaims(); !ok { + return &ValidationError{Name: "upstream_identity_claims", err: errors.New(`ent: missing required field "PendingAuthSession.upstream_identity_claims"`)} + } + if _, ok := _c.mutation.LocalFlowState(); !ok { + return &ValidationError{Name: "local_flow_state", err: errors.New(`ent: missing required field "PendingAuthSession.local_flow_state"`)} + } + if _, ok := _c.mutation.BrowserSessionKey(); !ok { + return &ValidationError{Name: "browser_session_key", err: errors.New(`ent: missing required field "PendingAuthSession.browser_session_key"`)} + } + if _, ok := _c.mutation.CompletionCodeHash(); !ok { + return &ValidationError{Name: "completion_code_hash", err: errors.New(`ent: missing required field "PendingAuthSession.completion_code_hash"`)} + } + if _, ok := _c.mutation.ExpiresAt(); !ok { + return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "PendingAuthSession.expires_at"`)} + } + return nil +} + +func (_c *PendingAuthSessionCreate) sqlSave(ctx context.Context) (*PendingAuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *PendingAuthSessionCreate) createSpec() (*PendingAuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &PendingAuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(pendingauthsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + _node.SessionToken = value + } + if value, ok := _c.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + _node.Intent = value + } + if value, ok := _c.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + _node.ProviderType = value + } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = value + } + if value, ok := _c.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + _node.ProviderSubject = value + } + if value, ok := _c.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + _node.RedirectTo = value + } + if value, ok := _c.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + _node.ResolvedEmail = value + } + if value, ok := _c.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + _node.RegistrationPasswordHash = value + } + if value, ok := _c.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + _node.UpstreamIdentityClaims = value + } + if value, ok := _c.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + _node.LocalFlowState = value + } + if value, ok := _c.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + _node.BrowserSessionKey = value + } + if value, ok := _c.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + _node.CompletionCodeHash = value + } + if value, ok := _c.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + _node.CompletionCodeExpiresAt = &value + } + if value, ok := _c.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + _node.EmailVerifiedAt = &value + } + if value, ok := _c.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + _node.PasswordVerifiedAt = &value + } + if value, ok := _c.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + _node.TotpVerifiedAt = &value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = value + } + if value, ok := _c.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + _node.ConsumedAt = &value + } + if nodes := _c.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.TargetUserID = &nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertOne { + _c.conflict = opts + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreate) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertOne{ + create: _c, + } +} + +type ( + // PendingAuthSessionUpsertOne is the builder for "upsert"-ing + // one PendingAuthSession node. + PendingAuthSessionUpsertOne struct { + create *PendingAuthSessionCreate + } + + // PendingAuthSessionUpsert is the "OnConflict" setter. + PendingAuthSessionUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsert) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpdatedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpdatedAt) + return u +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsert) SetSessionToken(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldSessionToken, v) + return u +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateSessionToken() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldSessionToken) + return u +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsert) SetIntent(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldIntent, v) + return u +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateIntent() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldIntent) + return u +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsert) SetProviderType(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderType, v) + return u +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderType() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderType) + return u +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsert) SetProviderKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderKey) + return u +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsert) SetProviderSubject(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldProviderSubject, v) + return u +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateProviderSubject() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldProviderSubject) + return u +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsert) SetTargetUserID(v int64) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTargetUserID, v) + return u +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTargetUserID() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTargetUserID) + return u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsert) ClearTargetUserID() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTargetUserID) + return u +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsert) SetRedirectTo(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRedirectTo, v) + return u +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRedirectTo() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRedirectTo) + return u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsert) SetResolvedEmail(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldResolvedEmail, v) + return u +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateResolvedEmail() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldResolvedEmail) + return u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsert) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldRegistrationPasswordHash, v) + return u +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldRegistrationPasswordHash) + return u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsert) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldUpstreamIdentityClaims, v) + return u +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldUpstreamIdentityClaims) + return u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsert) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldLocalFlowState, v) + return u +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateLocalFlowState() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldLocalFlowState) + return u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsert) SetBrowserSessionKey(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldBrowserSessionKey, v) + return u +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateBrowserSessionKey() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldBrowserSessionKey) + return u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeHash(v string) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeHash, v) + return u +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeHash() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeHash) + return u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldCompletionCodeExpiresAt, v) + return u +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsert) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldCompletionCodeExpiresAt) + return u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldEmailVerifiedAt, v) + return u +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearEmailVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldEmailVerifiedAt) + return u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldPasswordVerifiedAt, v) + return u +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearPasswordVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldPasswordVerifiedAt) + return u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldTotpVerifiedAt, v) + return u +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsert) ClearTotpVerifiedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldTotpVerifiedAt) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsert) SetExpiresAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateExpiresAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldExpiresAt) + return u +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsert) SetConsumedAt(v time.Time) *PendingAuthSessionUpsert { + u.Set(pendingauthsession.FieldConsumedAt, v) + return u +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsert) UpdateConsumedAt() *PendingAuthSessionUpsert { + u.SetExcluded(pendingauthsession.FieldConsumedAt) + return u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsert) ClearConsumedAt() *PendingAuthSessionUpsert { + u.SetNull(pendingauthsession.FieldConsumedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) UpdateNewValues() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertOne) Ignore() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertOne) DoNothing() *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreate.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertOne) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpdatedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertOne) SetSessionToken(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateSessionToken() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertOne) SetIntent(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateIntent() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertOne) SetProviderType(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderType() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertOne) SetProviderKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertOne) SetProviderSubject(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateProviderSubject() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) SetTargetUserID(v int64) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertOne) ClearTargetUserID() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertOne) SetRedirectTo(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRedirectTo() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertOne) SetResolvedEmail(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateResolvedEmail() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateLocalFlowState() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateBrowserSessionKey() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeHash() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearEmailVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertOne) ClearTotpVerifiedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateExpiresAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertOne) UpdateConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertOne) ClearConsumedAt() *PendingAuthSessionUpsertOne { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *PendingAuthSessionUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *PendingAuthSessionUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// PendingAuthSessionCreateBulk is the builder for creating many PendingAuthSession entities in bulk. +type PendingAuthSessionCreateBulk struct { + config + err error + builders []*PendingAuthSessionCreate + conflict []sql.ConflictOption +} + +// Save creates the PendingAuthSession entities in the database. +func (_c *PendingAuthSessionCreateBulk) Save(ctx context.Context) ([]*PendingAuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*PendingAuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*PendingAuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) SaveX(ctx context.Context) []*PendingAuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *PendingAuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *PendingAuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.PendingAuthSession.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.PendingAuthSessionUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflict(opts ...sql.ConflictOption) *PendingAuthSessionUpsertBulk { + _c.conflict = opts + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *PendingAuthSessionCreateBulk) OnConflictColumns(columns ...string) *PendingAuthSessionUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &PendingAuthSessionUpsertBulk{ + create: _c, + } +} + +// PendingAuthSessionUpsertBulk is the builder for "upsert"-ing +// a bulk of PendingAuthSession nodes. +type PendingAuthSessionUpsertBulk struct { + create *PendingAuthSessionCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) UpdateNewValues() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(pendingauthsession.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.PendingAuthSession.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *PendingAuthSessionUpsertBulk) Ignore() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *PendingAuthSessionUpsertBulk) DoNothing() *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the PendingAuthSessionCreateBulk.OnConflict +// documentation for more info. +func (u *PendingAuthSessionUpsertBulk) Update(set func(*PendingAuthSessionUpsert)) *PendingAuthSessionUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&PendingAuthSessionUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *PendingAuthSessionUpsertBulk) SetUpdatedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpdatedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetSessionToken sets the "session_token" field. +func (u *PendingAuthSessionUpsertBulk) SetSessionToken(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetSessionToken(v) + }) +} + +// UpdateSessionToken sets the "session_token" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateSessionToken() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateSessionToken() + }) +} + +// SetIntent sets the "intent" field. +func (u *PendingAuthSessionUpsertBulk) SetIntent(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetIntent(v) + }) +} + +// UpdateIntent sets the "intent" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateIntent() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateIntent() + }) +} + +// SetProviderType sets the "provider_type" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderType(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderType(v) + }) +} + +// UpdateProviderType sets the "provider_type" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderType() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderType() + }) +} + +// SetProviderKey sets the "provider_key" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderKey() + }) +} + +// SetProviderSubject sets the "provider_subject" field. +func (u *PendingAuthSessionUpsertBulk) SetProviderSubject(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetProviderSubject(v) + }) +} + +// UpdateProviderSubject sets the "provider_subject" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateProviderSubject() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateProviderSubject() + }) +} + +// SetTargetUserID sets the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) SetTargetUserID(v int64) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTargetUserID(v) + }) +} + +// UpdateTargetUserID sets the "target_user_id" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTargetUserID() + }) +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (u *PendingAuthSessionUpsertBulk) ClearTargetUserID() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTargetUserID() + }) +} + +// SetRedirectTo sets the "redirect_to" field. +func (u *PendingAuthSessionUpsertBulk) SetRedirectTo(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRedirectTo(v) + }) +} + +// UpdateRedirectTo sets the "redirect_to" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRedirectTo() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRedirectTo() + }) +} + +// SetResolvedEmail sets the "resolved_email" field. +func (u *PendingAuthSessionUpsertBulk) SetResolvedEmail(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetResolvedEmail(v) + }) +} + +// UpdateResolvedEmail sets the "resolved_email" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateResolvedEmail() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateResolvedEmail() + }) +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetRegistrationPasswordHash(v) + }) +} + +// UpdateRegistrationPasswordHash sets the "registration_password_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateRegistrationPasswordHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateRegistrationPasswordHash() + }) +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (u *PendingAuthSessionUpsertBulk) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetUpstreamIdentityClaims(v) + }) +} + +// UpdateUpstreamIdentityClaims sets the "upstream_identity_claims" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateUpstreamIdentityClaims() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateUpstreamIdentityClaims() + }) +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (u *PendingAuthSessionUpsertBulk) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetLocalFlowState(v) + }) +} + +// UpdateLocalFlowState sets the "local_flow_state" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateLocalFlowState() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateLocalFlowState() + }) +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (u *PendingAuthSessionUpsertBulk) SetBrowserSessionKey(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetBrowserSessionKey(v) + }) +} + +// UpdateBrowserSessionKey sets the "browser_session_key" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateBrowserSessionKey() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateBrowserSessionKey() + }) +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeHash(v string) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeHash(v) + }) +} + +// UpdateCompletionCodeHash sets the "completion_code_hash" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeHash() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeHash() + }) +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetCompletionCodeExpiresAt(v) + }) +} + +// UpdateCompletionCodeExpiresAt sets the "completion_code_expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateCompletionCodeExpiresAt() + }) +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearCompletionCodeExpiresAt() + }) +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetEmailVerifiedAt(v) + }) +} + +// UpdateEmailVerifiedAt sets the "email_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateEmailVerifiedAt() + }) +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearEmailVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearEmailVerifiedAt() + }) +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetPasswordVerifiedAt(v) + }) +} + +// UpdatePasswordVerifiedAt sets the "password_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdatePasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdatePasswordVerifiedAt() + }) +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearPasswordVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearPasswordVerifiedAt() + }) +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetTotpVerifiedAt(v) + }) +} + +// UpdateTotpVerifiedAt sets the "totp_verified_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateTotpVerifiedAt() + }) +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearTotpVerifiedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearTotpVerifiedAt() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *PendingAuthSessionUpsertBulk) SetExpiresAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateExpiresAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateExpiresAt() + }) +} + +// SetConsumedAt sets the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) SetConsumedAt(v time.Time) *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.SetConsumedAt(v) + }) +} + +// UpdateConsumedAt sets the "consumed_at" field to the value that was provided on create. +func (u *PendingAuthSessionUpsertBulk) UpdateConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.UpdateConsumedAt() + }) +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (u *PendingAuthSessionUpsertBulk) ClearConsumedAt() *PendingAuthSessionUpsertBulk { + return u.Update(func(s *PendingAuthSessionUpsert) { + s.ClearConsumedAt() + }) +} + +// Exec executes the query. +func (u *PendingAuthSessionUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the PendingAuthSessionCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for PendingAuthSessionCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *PendingAuthSessionUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_delete.go b/backend/ent/pendingauthsession_delete.go new file mode 100644 index 00000000..ee4fe605 --- /dev/null +++ b/backend/ent/pendingauthsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// PendingAuthSessionDelete is the builder for deleting a PendingAuthSession entity. +type PendingAuthSessionDelete struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDelete) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *PendingAuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *PendingAuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(pendingauthsession.Table, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// PendingAuthSessionDeleteOne is the builder for deleting a single PendingAuthSession entity. +type PendingAuthSessionDeleteOne struct { + _d *PendingAuthSessionDelete +} + +// Where appends a list predicates to the PendingAuthSessionDelete builder. +func (_d *PendingAuthSessionDeleteOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *PendingAuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{pendingauthsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *PendingAuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/pendingauthsession_query.go b/backend/ent/pendingauthsession_query.go new file mode 100644 index 00000000..78e29cd2 --- /dev/null +++ b/backend/ent/pendingauthsession_query.go @@ -0,0 +1,717 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "database/sql/driver" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionQuery is the builder for querying PendingAuthSession entities. +type PendingAuthSessionQuery struct { + config + ctx *QueryContext + order []pendingauthsession.OrderOption + inters []Interceptor + predicates []predicate.PendingAuthSession + withTargetUser *UserQuery + withAdoptionDecision *IdentityAdoptionDecisionQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the PendingAuthSessionQuery builder. +func (_q *PendingAuthSessionQuery) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *PendingAuthSessionQuery) Limit(limit int) *PendingAuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *PendingAuthSessionQuery) Offset(offset int) *PendingAuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *PendingAuthSessionQuery) Unique(unique bool) *PendingAuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *PendingAuthSessionQuery) Order(o ...pendingauthsession.OrderOption) *PendingAuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// QueryTargetUser chains the current query on the "target_user" edge. +func (_q *PendingAuthSessionQuery) QueryTargetUser() *UserQuery { + query := (&UserClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, pendingauthsession.TargetUserTable, pendingauthsession.TargetUserColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryAdoptionDecision chains the current query on the "adoption_decision" edge. +func (_q *PendingAuthSessionQuery) QueryAdoptionDecision() *IdentityAdoptionDecisionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(pendingauthsession.Table, pendingauthsession.FieldID, selector), + sqlgraph.To(identityadoptiondecision.Table, identityadoptiondecision.FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, pendingauthsession.AdoptionDecisionTable, pendingauthsession.AdoptionDecisionColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first PendingAuthSession entity from the query. +// Returns a *NotFoundError when no PendingAuthSession was found. +func (_q *PendingAuthSessionQuery) First(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{pendingauthsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstX(ctx context.Context) *PendingAuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first PendingAuthSession ID from the query. +// Returns a *NotFoundError when no PendingAuthSession ID was found. +func (_q *PendingAuthSessionQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{pendingauthsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single PendingAuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one PendingAuthSession entity is found. +// Returns a *NotFoundError when no PendingAuthSession entities are found. +func (_q *PendingAuthSessionQuery) Only(ctx context.Context) (*PendingAuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{pendingauthsession.Label} + default: + return nil, &NotSingularError{pendingauthsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyX(ctx context.Context) *PendingAuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only PendingAuthSession ID in the query. +// Returns a *NotSingularError when more than one PendingAuthSession ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *PendingAuthSessionQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{pendingauthsession.Label} + default: + err = &NotSingularError{pendingauthsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of PendingAuthSessions. +func (_q *PendingAuthSessionQuery) All(ctx context.Context) ([]*PendingAuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*PendingAuthSession, *PendingAuthSessionQuery]() + return withInterceptors[[]*PendingAuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) AllX(ctx context.Context) []*PendingAuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of PendingAuthSession IDs. +func (_q *PendingAuthSessionQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(pendingauthsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *PendingAuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*PendingAuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *PendingAuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *PendingAuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the PendingAuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *PendingAuthSessionQuery) Clone() *PendingAuthSessionQuery { + if _q == nil { + return nil + } + return &PendingAuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]pendingauthsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.PendingAuthSession{}, _q.predicates...), + withTargetUser: _q.withTargetUser.Clone(), + withAdoptionDecision: _q.withAdoptionDecision.Clone(), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// WithTargetUser tells the query-builder to eager-load the nodes that are connected to +// the "target_user" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithTargetUser(opts ...func(*UserQuery)) *PendingAuthSessionQuery { + query := (&UserClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withTargetUser = query + return _q +} + +// WithAdoptionDecision tells the query-builder to eager-load the nodes that are connected to +// the "adoption_decision" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *PendingAuthSessionQuery) WithAdoptionDecision(opts ...func(*IdentityAdoptionDecisionQuery)) *PendingAuthSessionQuery { + query := (&IdentityAdoptionDecisionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAdoptionDecision = query + return _q +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// GroupBy(pendingauthsession.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) GroupBy(field string, fields ...string) *PendingAuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &PendingAuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = pendingauthsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.PendingAuthSession.Query(). +// Select(pendingauthsession.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *PendingAuthSessionQuery) Select(fields ...string) *PendingAuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &PendingAuthSessionSelect{PendingAuthSessionQuery: _q} + sbuild.label = pendingauthsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PendingAuthSessionSelect configured with the given aggregations. +func (_q *PendingAuthSessionQuery) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *PendingAuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !pendingauthsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*PendingAuthSession, error) { + var ( + nodes = []*PendingAuthSession{} + _spec = _q.querySpec() + loadedTypes = [2]bool{ + _q.withTargetUser != nil, + _q.withAdoptionDecision != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*PendingAuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &PendingAuthSession{config: _q.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := _q.withTargetUser; query != nil { + if err := _q.loadTargetUser(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *User) { n.Edges.TargetUser = e }); err != nil { + return nil, err + } + } + if query := _q.withAdoptionDecision; query != nil { + if err := _q.loadAdoptionDecision(ctx, query, nodes, nil, + func(n *PendingAuthSession, e *IdentityAdoptionDecision) { n.Edges.AdoptionDecision = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (_q *PendingAuthSessionQuery) loadTargetUser(ctx context.Context, query *UserQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *User)) error { + ids := make([]int64, 0, len(nodes)) + nodeids := make(map[int64][]*PendingAuthSession) + for i := range nodes { + if nodes[i].TargetUserID == nil { + continue + } + fk := *nodes[i].TargetUserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "target_user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (_q *PendingAuthSessionQuery) loadAdoptionDecision(ctx context.Context, query *IdentityAdoptionDecisionQuery, nodes []*PendingAuthSession, init func(*PendingAuthSession), assign func(*PendingAuthSession, *IdentityAdoptionDecision)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*PendingAuthSession) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(identityadoptiondecision.FieldPendingAuthSessionID) + } + query.Where(predicate.IdentityAdoptionDecision(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(pendingauthsession.AdoptionDecisionColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.PendingAuthSessionID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "pending_auth_session_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} + +func (_q *PendingAuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *PendingAuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for i := range fields { + if fields[i] != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if _q.withTargetUser != nil { + _spec.Node.AddColumnOnce(pendingauthsession.FieldTargetUserID) + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *PendingAuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(pendingauthsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = pendingauthsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *PendingAuthSessionQuery) ForUpdate(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *PendingAuthSessionQuery) ForShare(opts ...sql.LockOption) *PendingAuthSessionQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// PendingAuthSessionGroupBy is the group-by builder for PendingAuthSession entities. +type PendingAuthSessionGroupBy struct { + selector + build *PendingAuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *PendingAuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *PendingAuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *PendingAuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *PendingAuthSessionGroupBy) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// PendingAuthSessionSelect is the builder for selecting fields of PendingAuthSession entities. +type PendingAuthSessionSelect struct { + *PendingAuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *PendingAuthSessionSelect) Aggregate(fns ...AggregateFunc) *PendingAuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *PendingAuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*PendingAuthSessionQuery, *PendingAuthSessionSelect](ctx, _s.PendingAuthSessionQuery, _s, _s.inters, v) +} + +func (_s *PendingAuthSessionSelect) sqlScan(ctx context.Context, root *PendingAuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/pendingauthsession_update.go b/backend/ent/pendingauthsession_update.go new file mode 100644 index 00000000..00066f69 --- /dev/null +++ b/backend/ent/pendingauthsession_update.go @@ -0,0 +1,1178 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/user" +) + +// PendingAuthSessionUpdate is the builder for updating PendingAuthSession entities. +type PendingAuthSessionUpdate struct { + config + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdate) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdate) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdate) SetSessionToken(v string) *PendingAuthSessionUpdate { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableSessionToken(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdate) SetIntent(v string) *PendingAuthSessionUpdate { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableIntent(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdate) SetProviderType(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderType(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdate) SetProviderKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdate) SetProviderSubject(v string) *PendingAuthSessionUpdate { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) SetTargetUserID(v int64) *PendingAuthSessionUpdate { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdate) ClearTargetUserID() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdate) SetRedirectTo(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdate) SetResolvedEmail(v string) *PendingAuthSessionUpdate { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdate) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdate) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdate) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdate { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdate) SetBrowserSessionKey(v string) *PendingAuthSessionUpdate { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeHash(v string) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdate) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdate { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearEmailVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearPasswordVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdate) ClearTotpVerifiedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdate) SetExpiresAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) SetConsumedAt(v time.Time) *PendingAuthSessionUpdate { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdate { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdate) ClearConsumedAt() *PendingAuthSessionUpdate { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) SetTargetUser(v *User) *PendingAuthSessionUpdate { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdate { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdate) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdate { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdate { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdate) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdate) ClearTargetUser() *PendingAuthSessionUpdate { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdate) ClearAdoptionDecision() *PendingAuthSessionUpdate { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *PendingAuthSessionUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *PendingAuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdate) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// PendingAuthSessionUpdateOne is the builder for updating a single PendingAuthSession entity. +type PendingAuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *PendingAuthSessionMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *PendingAuthSessionUpdateOne) SetUpdatedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetSessionToken sets the "session_token" field. +func (_u *PendingAuthSessionUpdateOne) SetSessionToken(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetSessionToken(v) + return _u +} + +// SetNillableSessionToken sets the "session_token" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableSessionToken(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetSessionToken(*v) + } + return _u +} + +// SetIntent sets the "intent" field. +func (_u *PendingAuthSessionUpdateOne) SetIntent(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetIntent(v) + return _u +} + +// SetNillableIntent sets the "intent" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableIntent(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetIntent(*v) + } + return _u +} + +// SetProviderType sets the "provider_type" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderType(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderType(v) + return _u +} + +// SetNillableProviderType sets the "provider_type" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderType(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderType(*v) + } + return _u +} + +// SetProviderKey sets the "provider_key" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// SetProviderSubject sets the "provider_subject" field. +func (_u *PendingAuthSessionUpdateOne) SetProviderSubject(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetProviderSubject(v) + return _u +} + +// SetNillableProviderSubject sets the "provider_subject" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableProviderSubject(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetProviderSubject(*v) + } + return _u +} + +// SetTargetUserID sets the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) SetTargetUserID(v int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetTargetUserID(v) + return _u +} + +// SetNillableTargetUserID sets the "target_user_id" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTargetUserID(v *int64) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTargetUserID(*v) + } + return _u +} + +// ClearTargetUserID clears the value of the "target_user_id" field. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUserID() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUserID() + return _u +} + +// SetRedirectTo sets the "redirect_to" field. +func (_u *PendingAuthSessionUpdateOne) SetRedirectTo(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRedirectTo(v) + return _u +} + +// SetNillableRedirectTo sets the "redirect_to" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRedirectTo(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRedirectTo(*v) + } + return _u +} + +// SetResolvedEmail sets the "resolved_email" field. +func (_u *PendingAuthSessionUpdateOne) SetResolvedEmail(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetResolvedEmail(v) + return _u +} + +// SetNillableResolvedEmail sets the "resolved_email" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableResolvedEmail(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetResolvedEmail(*v) + } + return _u +} + +// SetRegistrationPasswordHash sets the "registration_password_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetRegistrationPasswordHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetRegistrationPasswordHash(v) + return _u +} + +// SetNillableRegistrationPasswordHash sets the "registration_password_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableRegistrationPasswordHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetRegistrationPasswordHash(*v) + } + return _u +} + +// SetUpstreamIdentityClaims sets the "upstream_identity_claims" field. +func (_u *PendingAuthSessionUpdateOne) SetUpstreamIdentityClaims(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetUpstreamIdentityClaims(v) + return _u +} + +// SetLocalFlowState sets the "local_flow_state" field. +func (_u *PendingAuthSessionUpdateOne) SetLocalFlowState(v map[string]interface{}) *PendingAuthSessionUpdateOne { + _u.mutation.SetLocalFlowState(v) + return _u +} + +// SetBrowserSessionKey sets the "browser_session_key" field. +func (_u *PendingAuthSessionUpdateOne) SetBrowserSessionKey(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetBrowserSessionKey(v) + return _u +} + +// SetNillableBrowserSessionKey sets the "browser_session_key" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableBrowserSessionKey(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetBrowserSessionKey(*v) + } + return _u +} + +// SetCompletionCodeHash sets the "completion_code_hash" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeHash(v string) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeHash(v) + return _u +} + +// SetNillableCompletionCodeHash sets the "completion_code_hash" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeHash(v *string) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeHash(*v) + } + return _u +} + +// SetCompletionCodeExpiresAt sets the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetCompletionCodeExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetCompletionCodeExpiresAt(v) + return _u +} + +// SetNillableCompletionCodeExpiresAt sets the "completion_code_expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableCompletionCodeExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetCompletionCodeExpiresAt(*v) + } + return _u +} + +// ClearCompletionCodeExpiresAt clears the value of the "completion_code_expires_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearCompletionCodeExpiresAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearCompletionCodeExpiresAt() + return _u +} + +// SetEmailVerifiedAt sets the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetEmailVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetEmailVerifiedAt(v) + return _u +} + +// SetNillableEmailVerifiedAt sets the "email_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableEmailVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetEmailVerifiedAt(*v) + } + return _u +} + +// ClearEmailVerifiedAt clears the value of the "email_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearEmailVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearEmailVerifiedAt() + return _u +} + +// SetPasswordVerifiedAt sets the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetPasswordVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetPasswordVerifiedAt(v) + return _u +} + +// SetNillablePasswordVerifiedAt sets the "password_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillablePasswordVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetPasswordVerifiedAt(*v) + } + return _u +} + +// ClearPasswordVerifiedAt clears the value of the "password_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearPasswordVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearPasswordVerifiedAt() + return _u +} + +// SetTotpVerifiedAt sets the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) SetTotpVerifiedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetTotpVerifiedAt(v) + return _u +} + +// SetNillableTotpVerifiedAt sets the "totp_verified_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableTotpVerifiedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetTotpVerifiedAt(*v) + } + return _u +} + +// ClearTotpVerifiedAt clears the value of the "totp_verified_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearTotpVerifiedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTotpVerifiedAt() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *PendingAuthSessionUpdateOne) SetExpiresAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableExpiresAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// SetConsumedAt sets the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) SetConsumedAt(v time.Time) *PendingAuthSessionUpdateOne { + _u.mutation.SetConsumedAt(v) + return _u +} + +// SetNillableConsumedAt sets the "consumed_at" field if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableConsumedAt(v *time.Time) *PendingAuthSessionUpdateOne { + if v != nil { + _u.SetConsumedAt(*v) + } + return _u +} + +// ClearConsumedAt clears the value of the "consumed_at" field. +func (_u *PendingAuthSessionUpdateOne) ClearConsumedAt() *PendingAuthSessionUpdateOne { + _u.mutation.ClearConsumedAt() + return _u +} + +// SetTargetUser sets the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) SetTargetUser(v *User) *PendingAuthSessionUpdateOne { + return _u.SetTargetUserID(v.ID) +} + +// SetAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecisionID(id int64) *PendingAuthSessionUpdateOne { + _u.mutation.SetAdoptionDecisionID(id) + return _u +} + +// SetNillableAdoptionDecisionID sets the "adoption_decision" edge to the IdentityAdoptionDecision entity by ID if the given value is not nil. +func (_u *PendingAuthSessionUpdateOne) SetNillableAdoptionDecisionID(id *int64) *PendingAuthSessionUpdateOne { + if id != nil { + _u = _u.SetAdoptionDecisionID(*id) + } + return _u +} + +// SetAdoptionDecision sets the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) SetAdoptionDecision(v *IdentityAdoptionDecision) *PendingAuthSessionUpdateOne { + return _u.SetAdoptionDecisionID(v.ID) +} + +// Mutation returns the PendingAuthSessionMutation object of the builder. +func (_u *PendingAuthSessionUpdateOne) Mutation() *PendingAuthSessionMutation { + return _u.mutation +} + +// ClearTargetUser clears the "target_user" edge to the User entity. +func (_u *PendingAuthSessionUpdateOne) ClearTargetUser() *PendingAuthSessionUpdateOne { + _u.mutation.ClearTargetUser() + return _u +} + +// ClearAdoptionDecision clears the "adoption_decision" edge to the IdentityAdoptionDecision entity. +func (_u *PendingAuthSessionUpdateOne) ClearAdoptionDecision() *PendingAuthSessionUpdateOne { + _u.mutation.ClearAdoptionDecision() + return _u +} + +// Where appends a list predicates to the PendingAuthSessionUpdate builder. +func (_u *PendingAuthSessionUpdateOne) Where(ps ...predicate.PendingAuthSession) *PendingAuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *PendingAuthSessionUpdateOne) Select(field string, fields ...string) *PendingAuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated PendingAuthSession entity. +func (_u *PendingAuthSessionUpdateOne) Save(ctx context.Context) (*PendingAuthSession, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) SaveX(ctx context.Context) *PendingAuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *PendingAuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *PendingAuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *PendingAuthSessionUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := pendingauthsession.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *PendingAuthSessionUpdateOne) check() error { + if v, ok := _u.mutation.SessionToken(); ok { + if err := pendingauthsession.SessionTokenValidator(v); err != nil { + return &ValidationError{Name: "session_token", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.session_token": %w`, err)} + } + } + if v, ok := _u.mutation.Intent(); ok { + if err := pendingauthsession.IntentValidator(v); err != nil { + return &ValidationError{Name: "intent", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.intent": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderType(); ok { + if err := pendingauthsession.ProviderTypeValidator(v); err != nil { + return &ValidationError{Name: "provider_type", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_type": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := pendingauthsession.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_key": %w`, err)} + } + } + if v, ok := _u.mutation.ProviderSubject(); ok { + if err := pendingauthsession.ProviderSubjectValidator(v); err != nil { + return &ValidationError{Name: "provider_subject", err: fmt.Errorf(`ent: validator failed for field "PendingAuthSession.provider_subject": %w`, err)} + } + } + return nil +} + +func (_u *PendingAuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *PendingAuthSession, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(pendingauthsession.Table, pendingauthsession.Columns, sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PendingAuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, pendingauthsession.FieldID) + for _, f := range fields { + if !pendingauthsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != pendingauthsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(pendingauthsession.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.SessionToken(); ok { + _spec.SetField(pendingauthsession.FieldSessionToken, field.TypeString, value) + } + if value, ok := _u.mutation.Intent(); ok { + _spec.SetField(pendingauthsession.FieldIntent, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderType(); ok { + _spec.SetField(pendingauthsession.FieldProviderType, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(pendingauthsession.FieldProviderKey, field.TypeString, value) + } + if value, ok := _u.mutation.ProviderSubject(); ok { + _spec.SetField(pendingauthsession.FieldProviderSubject, field.TypeString, value) + } + if value, ok := _u.mutation.RedirectTo(); ok { + _spec.SetField(pendingauthsession.FieldRedirectTo, field.TypeString, value) + } + if value, ok := _u.mutation.ResolvedEmail(); ok { + _spec.SetField(pendingauthsession.FieldResolvedEmail, field.TypeString, value) + } + if value, ok := _u.mutation.RegistrationPasswordHash(); ok { + _spec.SetField(pendingauthsession.FieldRegistrationPasswordHash, field.TypeString, value) + } + if value, ok := _u.mutation.UpstreamIdentityClaims(); ok { + _spec.SetField(pendingauthsession.FieldUpstreamIdentityClaims, field.TypeJSON, value) + } + if value, ok := _u.mutation.LocalFlowState(); ok { + _spec.SetField(pendingauthsession.FieldLocalFlowState, field.TypeJSON, value) + } + if value, ok := _u.mutation.BrowserSessionKey(); ok { + _spec.SetField(pendingauthsession.FieldBrowserSessionKey, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeHash(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeHash, field.TypeString, value) + } + if value, ok := _u.mutation.CompletionCodeExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime, value) + } + if _u.mutation.CompletionCodeExpiresAtCleared() { + _spec.ClearField(pendingauthsession.FieldCompletionCodeExpiresAt, field.TypeTime) + } + if value, ok := _u.mutation.EmailVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime, value) + } + if _u.mutation.EmailVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldEmailVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.PasswordVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime, value) + } + if _u.mutation.PasswordVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldPasswordVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.TotpVerifiedAt(); ok { + _spec.SetField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime, value) + } + if _u.mutation.TotpVerifiedAtCleared() { + _spec.ClearField(pendingauthsession.FieldTotpVerifiedAt, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(pendingauthsession.FieldExpiresAt, field.TypeTime, value) + } + if value, ok := _u.mutation.ConsumedAt(); ok { + _spec.SetField(pendingauthsession.FieldConsumedAt, field.TypeTime, value) + } + if _u.mutation.ConsumedAtCleared() { + _spec.ClearField(pendingauthsession.FieldConsumedAt, field.TypeTime) + } + if _u.mutation.TargetUserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.TargetUserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: pendingauthsession.TargetUserTable, + Columns: []string{pendingauthsession.TargetUserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.AdoptionDecisionCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AdoptionDecisionIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2O, + Inverse: false, + Table: pendingauthsession.AdoptionDecisionTable, + Columns: []string{pendingauthsession.AdoptionDecisionColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(identityadoptiondecision.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _node = &PendingAuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{pendingauthsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index ef551940..0aa90b90 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,12 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// AuthIdentity is the predicate function for authidentity builders. +type AuthIdentity func(*sql.Selector) + +// AuthIdentityChannel is the predicate function for authidentitychannel builders. +type AuthIdentityChannel func(*sql.Selector) + // ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. type ErrorPassthroughRule func(*sql.Selector) @@ -30,6 +36,9 @@ type Group func(*sql.Selector) // IdempotencyRecord is the predicate function for idempotencyrecord builders. type IdempotencyRecord func(*sql.Selector) +// IdentityAdoptionDecision is the predicate function for identityadoptiondecision builders. +type IdentityAdoptionDecision func(*sql.Selector) + // PaymentAuditLog is the predicate function for paymentauditlog builders. type PaymentAuditLog func(*sql.Selector) @@ -39,6 +48,9 @@ type PaymentOrder func(*sql.Selector) // PaymentProviderInstance is the predicate function for paymentproviderinstance builders. type PaymentProviderInstance func(*sql.Selector) +// PendingAuthSession is the predicate function for pendingauthsession builders. +type PendingAuthSession func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fbdd08c7..268e9ddb 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,12 +10,16 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -309,6 +313,120 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + authidentityMixin := schema.AuthIdentity{}.Mixin() + authidentityMixinFields0 := authidentityMixin[0].Fields() + _ = authidentityMixinFields0 + authidentityFields := schema.AuthIdentity{}.Fields() + _ = authidentityFields + // authidentityDescCreatedAt is the schema descriptor for created_at field. + authidentityDescCreatedAt := authidentityMixinFields0[0].Descriptor() + // authidentity.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentity.DefaultCreatedAt = authidentityDescCreatedAt.Default.(func() time.Time) + // authidentityDescUpdatedAt is the schema descriptor for updated_at field. + authidentityDescUpdatedAt := authidentityMixinFields0[1].Descriptor() + // authidentity.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentity.DefaultUpdatedAt = authidentityDescUpdatedAt.Default.(func() time.Time) + // authidentity.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentity.UpdateDefaultUpdatedAt = authidentityDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentityDescProviderType is the schema descriptor for provider_type field. + authidentityDescProviderType := authidentityFields[1].Descriptor() + // authidentity.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentity.ProviderTypeValidator = func() func(string) error { + validators := authidentityDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentityDescProviderKey is the schema descriptor for provider_key field. + authidentityDescProviderKey := authidentityFields[2].Descriptor() + // authidentity.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentity.ProviderKeyValidator = authidentityDescProviderKey.Validators[0].(func(string) error) + // authidentityDescProviderSubject is the schema descriptor for provider_subject field. + authidentityDescProviderSubject := authidentityFields[3].Descriptor() + // authidentity.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + authidentity.ProviderSubjectValidator = authidentityDescProviderSubject.Validators[0].(func(string) error) + // authidentityDescMetadata is the schema descriptor for metadata field. + authidentityDescMetadata := authidentityFields[6].Descriptor() + // authidentity.DefaultMetadata holds the default value on creation for the metadata field. + authidentity.DefaultMetadata = authidentityDescMetadata.Default.(func() map[string]interface{}) + authidentitychannelMixin := schema.AuthIdentityChannel{}.Mixin() + authidentitychannelMixinFields0 := authidentitychannelMixin[0].Fields() + _ = authidentitychannelMixinFields0 + authidentitychannelFields := schema.AuthIdentityChannel{}.Fields() + _ = authidentitychannelFields + // authidentitychannelDescCreatedAt is the schema descriptor for created_at field. + authidentitychannelDescCreatedAt := authidentitychannelMixinFields0[0].Descriptor() + // authidentitychannel.DefaultCreatedAt holds the default value on creation for the created_at field. + authidentitychannel.DefaultCreatedAt = authidentitychannelDescCreatedAt.Default.(func() time.Time) + // authidentitychannelDescUpdatedAt is the schema descriptor for updated_at field. + authidentitychannelDescUpdatedAt := authidentitychannelMixinFields0[1].Descriptor() + // authidentitychannel.DefaultUpdatedAt holds the default value on creation for the updated_at field. + authidentitychannel.DefaultUpdatedAt = authidentitychannelDescUpdatedAt.Default.(func() time.Time) + // authidentitychannel.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + authidentitychannel.UpdateDefaultUpdatedAt = authidentitychannelDescUpdatedAt.UpdateDefault.(func() time.Time) + // authidentitychannelDescProviderType is the schema descriptor for provider_type field. + authidentitychannelDescProviderType := authidentitychannelFields[1].Descriptor() + // authidentitychannel.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + authidentitychannel.ProviderTypeValidator = func() func(string) error { + validators := authidentitychannelDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescProviderKey is the schema descriptor for provider_key field. + authidentitychannelDescProviderKey := authidentitychannelFields[2].Descriptor() + // authidentitychannel.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + authidentitychannel.ProviderKeyValidator = authidentitychannelDescProviderKey.Validators[0].(func(string) error) + // authidentitychannelDescChannel is the schema descriptor for channel field. + authidentitychannelDescChannel := authidentitychannelFields[3].Descriptor() + // authidentitychannel.ChannelValidator is a validator for the "channel" field. It is called by the builders before save. + authidentitychannel.ChannelValidator = func() func(string) error { + validators := authidentitychannelDescChannel.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(channel string) error { + for _, fn := range fns { + if err := fn(channel); err != nil { + return err + } + } + return nil + } + }() + // authidentitychannelDescChannelAppID is the schema descriptor for channel_app_id field. + authidentitychannelDescChannelAppID := authidentitychannelFields[4].Descriptor() + // authidentitychannel.ChannelAppIDValidator is a validator for the "channel_app_id" field. It is called by the builders before save. + authidentitychannel.ChannelAppIDValidator = authidentitychannelDescChannelAppID.Validators[0].(func(string) error) + // authidentitychannelDescChannelSubject is the schema descriptor for channel_subject field. + authidentitychannelDescChannelSubject := authidentitychannelFields[5].Descriptor() + // authidentitychannel.ChannelSubjectValidator is a validator for the "channel_subject" field. It is called by the builders before save. + authidentitychannel.ChannelSubjectValidator = authidentitychannelDescChannelSubject.Validators[0].(func(string) error) + // authidentitychannelDescMetadata is the schema descriptor for metadata field. + authidentitychannelDescMetadata := authidentitychannelFields[6].Descriptor() + // authidentitychannel.DefaultMetadata holds the default value on creation for the metadata field. + authidentitychannel.DefaultMetadata = authidentitychannelDescMetadata.Default.(func() map[string]interface{}) errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() _ = errorpassthroughruleMixinFields0 @@ -512,6 +630,33 @@ func init() { idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) + identityadoptiondecisionMixin := schema.IdentityAdoptionDecision{}.Mixin() + identityadoptiondecisionMixinFields0 := identityadoptiondecisionMixin[0].Fields() + _ = identityadoptiondecisionMixinFields0 + identityadoptiondecisionFields := schema.IdentityAdoptionDecision{}.Fields() + _ = identityadoptiondecisionFields + // identityadoptiondecisionDescCreatedAt is the schema descriptor for created_at field. + identityadoptiondecisionDescCreatedAt := identityadoptiondecisionMixinFields0[0].Descriptor() + // identityadoptiondecision.DefaultCreatedAt holds the default value on creation for the created_at field. + identityadoptiondecision.DefaultCreatedAt = identityadoptiondecisionDescCreatedAt.Default.(func() time.Time) + // identityadoptiondecisionDescUpdatedAt is the schema descriptor for updated_at field. + identityadoptiondecisionDescUpdatedAt := identityadoptiondecisionMixinFields0[1].Descriptor() + // identityadoptiondecision.DefaultUpdatedAt holds the default value on creation for the updated_at field. + identityadoptiondecision.DefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.Default.(func() time.Time) + // identityadoptiondecision.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + identityadoptiondecision.UpdateDefaultUpdatedAt = identityadoptiondecisionDescUpdatedAt.UpdateDefault.(func() time.Time) + // identityadoptiondecisionDescAdoptDisplayName is the schema descriptor for adopt_display_name field. + identityadoptiondecisionDescAdoptDisplayName := identityadoptiondecisionFields[2].Descriptor() + // identityadoptiondecision.DefaultAdoptDisplayName holds the default value on creation for the adopt_display_name field. + identityadoptiondecision.DefaultAdoptDisplayName = identityadoptiondecisionDescAdoptDisplayName.Default.(bool) + // identityadoptiondecisionDescAdoptAvatar is the schema descriptor for adopt_avatar field. + identityadoptiondecisionDescAdoptAvatar := identityadoptiondecisionFields[3].Descriptor() + // identityadoptiondecision.DefaultAdoptAvatar holds the default value on creation for the adopt_avatar field. + identityadoptiondecision.DefaultAdoptAvatar = identityadoptiondecisionDescAdoptAvatar.Default.(bool) + // identityadoptiondecisionDescDecidedAt is the schema descriptor for decided_at field. + identityadoptiondecisionDescDecidedAt := identityadoptiondecisionFields[4].Descriptor() + // identityadoptiondecision.DefaultDecidedAt holds the default value on creation for the decided_at field. + identityadoptiondecision.DefaultDecidedAt = identityadoptiondecisionDescDecidedAt.Default.(func() time.Time) paymentauditlogFields := schema.PaymentAuditLog{}.Fields() _ = paymentauditlogFields // paymentauditlogDescOrderID is the schema descriptor for order_id field. @@ -682,6 +827,113 @@ func init() { paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. paymentproviderinstance.UpdateDefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.UpdateDefault.(func() time.Time) + pendingauthsessionMixin := schema.PendingAuthSession{}.Mixin() + pendingauthsessionMixinFields0 := pendingauthsessionMixin[0].Fields() + _ = pendingauthsessionMixinFields0 + pendingauthsessionFields := schema.PendingAuthSession{}.Fields() + _ = pendingauthsessionFields + // pendingauthsessionDescCreatedAt is the schema descriptor for created_at field. + pendingauthsessionDescCreatedAt := pendingauthsessionMixinFields0[0].Descriptor() + // pendingauthsession.DefaultCreatedAt holds the default value on creation for the created_at field. + pendingauthsession.DefaultCreatedAt = pendingauthsessionDescCreatedAt.Default.(func() time.Time) + // pendingauthsessionDescUpdatedAt is the schema descriptor for updated_at field. + pendingauthsessionDescUpdatedAt := pendingauthsessionMixinFields0[1].Descriptor() + // pendingauthsession.DefaultUpdatedAt holds the default value on creation for the updated_at field. + pendingauthsession.DefaultUpdatedAt = pendingauthsessionDescUpdatedAt.Default.(func() time.Time) + // pendingauthsession.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + pendingauthsession.UpdateDefaultUpdatedAt = pendingauthsessionDescUpdatedAt.UpdateDefault.(func() time.Time) + // pendingauthsessionDescSessionToken is the schema descriptor for session_token field. + pendingauthsessionDescSessionToken := pendingauthsessionFields[0].Descriptor() + // pendingauthsession.SessionTokenValidator is a validator for the "session_token" field. It is called by the builders before save. + pendingauthsession.SessionTokenValidator = func() func(string) error { + validators := pendingauthsessionDescSessionToken.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(session_token string) error { + for _, fn := range fns { + if err := fn(session_token); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescIntent is the schema descriptor for intent field. + pendingauthsessionDescIntent := pendingauthsessionFields[1].Descriptor() + // pendingauthsession.IntentValidator is a validator for the "intent" field. It is called by the builders before save. + pendingauthsession.IntentValidator = func() func(string) error { + validators := pendingauthsessionDescIntent.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(intent string) error { + for _, fn := range fns { + if err := fn(intent); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderType is the schema descriptor for provider_type field. + pendingauthsessionDescProviderType := pendingauthsessionFields[2].Descriptor() + // pendingauthsession.ProviderTypeValidator is a validator for the "provider_type" field. It is called by the builders before save. + pendingauthsession.ProviderTypeValidator = func() func(string) error { + validators := pendingauthsessionDescProviderType.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + validators[2].(func(string) error), + } + return func(provider_type string) error { + for _, fn := range fns { + if err := fn(provider_type); err != nil { + return err + } + } + return nil + } + }() + // pendingauthsessionDescProviderKey is the schema descriptor for provider_key field. + pendingauthsessionDescProviderKey := pendingauthsessionFields[3].Descriptor() + // pendingauthsession.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + pendingauthsession.ProviderKeyValidator = pendingauthsessionDescProviderKey.Validators[0].(func(string) error) + // pendingauthsessionDescProviderSubject is the schema descriptor for provider_subject field. + pendingauthsessionDescProviderSubject := pendingauthsessionFields[4].Descriptor() + // pendingauthsession.ProviderSubjectValidator is a validator for the "provider_subject" field. It is called by the builders before save. + pendingauthsession.ProviderSubjectValidator = pendingauthsessionDescProviderSubject.Validators[0].(func(string) error) + // pendingauthsessionDescRedirectTo is the schema descriptor for redirect_to field. + pendingauthsessionDescRedirectTo := pendingauthsessionFields[6].Descriptor() + // pendingauthsession.DefaultRedirectTo holds the default value on creation for the redirect_to field. + pendingauthsession.DefaultRedirectTo = pendingauthsessionDescRedirectTo.Default.(string) + // pendingauthsessionDescResolvedEmail is the schema descriptor for resolved_email field. + pendingauthsessionDescResolvedEmail := pendingauthsessionFields[7].Descriptor() + // pendingauthsession.DefaultResolvedEmail holds the default value on creation for the resolved_email field. + pendingauthsession.DefaultResolvedEmail = pendingauthsessionDescResolvedEmail.Default.(string) + // pendingauthsessionDescRegistrationPasswordHash is the schema descriptor for registration_password_hash field. + pendingauthsessionDescRegistrationPasswordHash := pendingauthsessionFields[8].Descriptor() + // pendingauthsession.DefaultRegistrationPasswordHash holds the default value on creation for the registration_password_hash field. + pendingauthsession.DefaultRegistrationPasswordHash = pendingauthsessionDescRegistrationPasswordHash.Default.(string) + // pendingauthsessionDescUpstreamIdentityClaims is the schema descriptor for upstream_identity_claims field. + pendingauthsessionDescUpstreamIdentityClaims := pendingauthsessionFields[9].Descriptor() + // pendingauthsession.DefaultUpstreamIdentityClaims holds the default value on creation for the upstream_identity_claims field. + pendingauthsession.DefaultUpstreamIdentityClaims = pendingauthsessionDescUpstreamIdentityClaims.Default.(func() map[string]interface{}) + // pendingauthsessionDescLocalFlowState is the schema descriptor for local_flow_state field. + pendingauthsessionDescLocalFlowState := pendingauthsessionFields[10].Descriptor() + // pendingauthsession.DefaultLocalFlowState holds the default value on creation for the local_flow_state field. + pendingauthsession.DefaultLocalFlowState = pendingauthsessionDescLocalFlowState.Default.(func() map[string]interface{}) + // pendingauthsessionDescBrowserSessionKey is the schema descriptor for browser_session_key field. + pendingauthsessionDescBrowserSessionKey := pendingauthsessionFields[11].Descriptor() + // pendingauthsession.DefaultBrowserSessionKey holds the default value on creation for the browser_session_key field. + pendingauthsession.DefaultBrowserSessionKey = pendingauthsessionDescBrowserSessionKey.Default.(string) + // pendingauthsessionDescCompletionCodeHash is the schema descriptor for completion_code_hash field. + pendingauthsessionDescCompletionCodeHash := pendingauthsessionFields[12].Descriptor() + // pendingauthsession.DefaultCompletionCodeHash holds the default value on creation for the completion_code_hash field. + pendingauthsession.DefaultCompletionCodeHash = pendingauthsessionDescCompletionCodeHash.Default.(string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -1297,20 +1549,26 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSignupSource is the schema descriptor for signup_source field. + userDescSignupSource := userFields[11].Descriptor() + // user.DefaultSignupSource holds the default value on creation for the signup_source field. + user.DefaultSignupSource = userDescSignupSource.Default.(string) + // user.SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + user.SignupSourceValidator = userDescSignupSource.Validators[0].(func(string) error) // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. - userDescBalanceNotifyEnabled := userFields[11].Descriptor() + userDescBalanceNotifyEnabled := userFields[14].Descriptor() // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. - userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + userDescBalanceNotifyThresholdType := userFields[15].Descriptor() // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. - userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + userDescBalanceNotifyExtraEmails := userFields[17].Descriptor() // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) // userDescTotalRecharged is the schema descriptor for total_recharged field. - userDescTotalRecharged := userFields[15].Descriptor() + userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go new file mode 100644 index 00000000..e4b9ac90 --- /dev/null +++ b/backend/ent/schema/auth_identity.go @@ -0,0 +1,93 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var authProviderTypes = map[string]struct{}{ + "email": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, +} + +func validateAuthProviderType(value string) error { + if _, ok := authProviderTypes[value]; ok { + return nil + } + return fmt.Errorf("invalid auth provider type %q", value) +} + +// AuthIdentity stores the canonical login identity for an account. +type AuthIdentity struct { + ent.Schema +} + +func (AuthIdentity) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identities"}, + } +} + +func (AuthIdentity) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentity) Fields() []ent.Field { + return []ent.Field{ + field.Int64("user_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.String("issuer"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentity) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type). + Ref("auth_identities"). + Field("user_id"). + Required(). + Unique(), + edge.To("channels", AuthIdentityChannel.Type), + edge.To("adoption_decisions", IdentityAdoptionDecision.Type), + } +} + +func (AuthIdentity) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "provider_subject").Unique(), + index.Fields("user_id"), + index.Fields("user_id", "provider_type"), + } +} diff --git a/backend/ent/schema/auth_identity_channel.go b/backend/ent/schema/auth_identity_channel.go new file mode 100644 index 00000000..69f2ad02 --- /dev/null +++ b/backend/ent/schema/auth_identity_channel.go @@ -0,0 +1,72 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// AuthIdentityChannel stores channel-scoped identifiers for a canonical identity. +type AuthIdentityChannel struct { + ent.Schema +} + +func (AuthIdentityChannel) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "auth_identity_channels"}, + } +} + +func (AuthIdentityChannel) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (AuthIdentityChannel) Fields() []ent.Field { + return []ent.Field{ + field.Int64("identity_id"), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel"). + MaxLen(20). + NotEmpty(), + field.String("channel_app_id"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("channel_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("metadata", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + } +} + +func (AuthIdentityChannel) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("identity", AuthIdentity.Type). + Ref("channels"). + Field("identity_id"). + Required(). + Unique(), + } +} + +func (AuthIdentityChannel) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("provider_type", "provider_key", "channel", "channel_app_id", "channel_subject").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go new file mode 100644 index 00000000..de55dd69 --- /dev/null +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -0,0 +1,124 @@ +package schema + +import ( + "testing" + + "entgo.io/ent/entc/load" + "github.com/stretchr/testify/require" +) + +func TestAuthIdentityFoundationSchemas(t *testing.T) { + spec, err := (&load.Config{Path: "."}).Load() + require.NoError(t, err) + + schemas := map[string]*load.Schema{} + for _, schema := range spec.Schemas { + schemas[schema.Name] = schema + } + + authIdentity := requireSchema(t, schemas, "AuthIdentity") + requireSchemaFields(t, authIdentity, + "user_id", + "provider_type", + "provider_key", + "provider_subject", + "verified_at", + "issuer", + "metadata", + ) + requireHasUniqueIndex(t, authIdentity, "provider_type", "provider_key", "provider_subject") + + authIdentityChannel := requireSchema(t, schemas, "AuthIdentityChannel") + requireSchemaFields(t, authIdentityChannel, + "identity_id", + "provider_type", + "provider_key", + "channel", + "channel_app_id", + "channel_subject", + "metadata", + ) + requireHasUniqueIndex(t, authIdentityChannel, "provider_type", "provider_key", "channel", "channel_app_id", "channel_subject") + + pendingAuthSession := requireSchema(t, schemas, "PendingAuthSession") + requireSchemaFields(t, pendingAuthSession, + "intent", + "provider_type", + "provider_key", + "provider_subject", + "target_user_id", + "redirect_to", + "resolved_email", + "registration_password_hash", + "upstream_identity_claims", + "local_flow_state", + "browser_session_key", + "completion_code_hash", + "completion_code_expires_at", + "email_verified_at", + "password_verified_at", + "totp_verified_at", + "expires_at", + "consumed_at", + ) + + adoptionDecision := requireSchema(t, schemas, "IdentityAdoptionDecision") + requireSchemaFields(t, adoptionDecision, + "pending_auth_session_id", + "identity_id", + "adopt_display_name", + "adopt_avatar", + "decided_at", + ) + requireHasUniqueIndex(t, adoptionDecision, "pending_auth_session_id") + + userSchema := requireSchema(t, schemas, "User") + requireSchemaFields(t, userSchema, "signup_source", "last_login_at", "last_active_at") +} + +func requireSchema(t *testing.T, schemas map[string]*load.Schema, name string) *load.Schema { + t.Helper() + + schema, ok := schemas[name] + require.True(t, ok, "schema %s should exist", name) + return schema +} + +func requireSchemaFields(t *testing.T, schema *load.Schema, names ...string) { + t.Helper() + + fields := map[string]struct{}{} + for _, field := range schema.Fields { + fields[field.Name] = struct{}{} + } + + for _, name := range names { + _, ok := fields[name] + require.True(t, ok, "schema %s should include field %s", schema.Name, name) + } +} + +func requireHasUniqueIndex(t *testing.T, schema *load.Schema, fields ...string) { + t.Helper() + + for _, index := range schema.Indexes { + if !index.Unique { + continue + } + if len(index.Fields) != len(fields) { + continue + } + match := true + for i := range fields { + if index.Fields[i] != fields[i] { + match = false + break + } + } + if match { + return + } + } + + require.Failf(t, "missing unique index", "schema %s should include unique index on %v", schema.Name, fields) +} diff --git a/backend/ent/schema/identity_adoption_decision.go b/backend/ent/schema/identity_adoption_decision.go new file mode 100644 index 00000000..9fdd26fb --- /dev/null +++ b/backend/ent/schema/identity_adoption_decision.go @@ -0,0 +1,70 @@ +package schema + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdentityAdoptionDecision stores the one-time profile adoption choice captured during a pending auth flow. +type IdentityAdoptionDecision struct { + ent.Schema +} + +func (IdentityAdoptionDecision) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "identity_adoption_decisions"}, + } +} + +func (IdentityAdoptionDecision) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdentityAdoptionDecision) Fields() []ent.Field { + return []ent.Field{ + field.Int64("pending_auth_session_id"), + field.Int64("identity_id"). + Optional(). + Nillable(), + field.Bool("adopt_display_name"). + Default(false), + field.Bool("adopt_avatar"). + Default(false), + field.Time("decided_at"). + Immutable(). + Default(time.Now). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (IdentityAdoptionDecision) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("pending_auth_session", PendingAuthSession.Type). + Ref("adoption_decision"). + Field("pending_auth_session_id"). + Required(). + Unique(), + edge.From("identity", AuthIdentity.Type). + Ref("adoption_decisions"). + Field("identity_id"). + Unique(), + } +} + +func (IdentityAdoptionDecision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("pending_auth_session_id").Unique(), + index.Fields("identity_id"), + } +} diff --git a/backend/ent/schema/pending_auth_session.go b/backend/ent/schema/pending_auth_session.go new file mode 100644 index 00000000..91341d49 --- /dev/null +++ b/backend/ent/schema/pending_auth_session.go @@ -0,0 +1,134 @@ +package schema + +import ( + "fmt" + + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +var pendingAuthIntents = map[string]struct{}{ + "login": {}, + "bind_current_user": {}, + "adopt_existing_user_by_email": {}, +} + +func validatePendingAuthIntent(value string) error { + if _, ok := pendingAuthIntents[value]; ok { + return nil + } + return fmt.Errorf("invalid pending auth intent %q", value) +} + +// PendingAuthSession stores a short-lived post-auth decision session. +type PendingAuthSession struct { + ent.Schema +} + +func (PendingAuthSession) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "pending_auth_sessions"}, + } +} + +func (PendingAuthSession) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (PendingAuthSession) Fields() []ent.Field { + return []ent.Field{ + field.String("session_token"). + MaxLen(255). + NotEmpty(), + field.String("intent"). + MaxLen(40). + NotEmpty(). + Validate(validatePendingAuthIntent), + field.String("provider_type"). + MaxLen(20). + NotEmpty(). + Validate(validateAuthProviderType), + field.String("provider_key"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("provider_subject"). + NotEmpty(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Int64("target_user_id"). + Optional(). + Nillable(), + field.String("redirect_to"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("resolved_email"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("registration_password_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.JSON("upstream_identity_claims", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.JSON("local_flow_state", map[string]any{}). + Default(func() map[string]any { return map[string]any{} }). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + field.String("browser_session_key"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.String("completion_code_hash"). + Default(""). + SchemaType(map[string]string{dialect.Postgres: "text"}), + field.Time("completion_code_expires_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("email_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("password_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("totp_verified_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("consumed_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + } +} + +func (PendingAuthSession) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("target_user", User.Type). + Ref("pending_auth_sessions"). + Field("target_user_id"). + Unique(), + edge.To("adoption_decision", IdentityAdoptionDecision.Type). + Unique(), + } +} + +func (PendingAuthSession) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("session_token").Unique(), + index.Fields("target_user_id"), + index.Fields("expires_at"), + index.Fields("provider_type", "provider_key", "provider_subject"), + index.Fields("completion_code_hash"), + } +} diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index ef52e985..bb58d9e3 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,17 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + field.String("signup_source"). + MaxLen(20). + Default("email"), + field.Time("last_login_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("last_active_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), // 余额不足通知 field.Bool("balance_notify_enabled"). @@ -104,6 +115,8 @@ func (User) Edges() []ent.Edge { edge.To("attribute_values", UserAttributeValue.Type), edge.To("promo_code_usages", PromoCodeUsage.Type), edge.To("payment_orders", PaymentOrder.Type), + edge.To("auth_identities", AuthIdentity.Type), + edge.To("pending_auth_sessions", PendingAuthSession.Type), } } diff --git a/backend/ent/tx.go b/backend/ent/tx.go index bb3139d5..bde3e35b 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,18 +24,26 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // AuthIdentity is the client for interacting with the AuthIdentity builders. + AuthIdentity *AuthIdentityClient + // AuthIdentityChannel is the client for interacting with the AuthIdentityChannel builders. + AuthIdentityChannel *AuthIdentityChannelClient // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. IdempotencyRecord *IdempotencyRecordClient + // IdentityAdoptionDecision is the client for interacting with the IdentityAdoptionDecision builders. + IdentityAdoptionDecision *IdentityAdoptionDecisionClient // PaymentAuditLog is the client for interacting with the PaymentAuditLog builders. PaymentAuditLog *PaymentAuditLogClient // PaymentOrder is the client for interacting with the PaymentOrder builders. PaymentOrder *PaymentOrderClient // PaymentProviderInstance is the client for interacting with the PaymentProviderInstance builders. PaymentProviderInstance *PaymentProviderInstanceClient + // PendingAuthSession is the client for interacting with the PendingAuthSession builders. + PendingAuthSession *PendingAuthSessionClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -202,12 +210,16 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.AuthIdentity = NewAuthIdentityClient(tx.config) + tx.AuthIdentityChannel = NewAuthIdentityChannelClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) + tx.IdentityAdoptionDecision = NewIdentityAdoptionDecisionClient(tx.config) tx.PaymentAuditLog = NewPaymentAuditLogClient(tx.config) tx.PaymentOrder = NewPaymentOrderClient(tx.config) tx.PaymentProviderInstance = NewPaymentProviderInstanceClient(tx.config) + tx.PendingAuthSession = NewPendingAuthSessionClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 9fa91f74..66f33623 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,12 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SignupSource holds the value of the "signup_source" field. + SignupSource string `json:"signup_source,omitempty"` + // LastLoginAt holds the value of the "last_login_at" field. + LastLoginAt *time.Time `json:"last_login_at,omitempty"` + // LastActiveAt holds the value of the "last_active_at" field. + LastActiveAt *time.Time `json:"last_active_at,omitempty"` // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. @@ -83,11 +89,15 @@ type UserEdges struct { PromoCodeUsages []*PromoCodeUsage `json:"promo_code_usages,omitempty"` // PaymentOrders holds the value of the payment_orders edge. PaymentOrders []*PaymentOrder `json:"payment_orders,omitempty"` + // AuthIdentities holds the value of the auth_identities edge. + AuthIdentities []*AuthIdentity `json:"auth_identities,omitempty"` + // PendingAuthSessions holds the value of the pending_auth_sessions edge. + PendingAuthSessions []*PendingAuthSession `json:"pending_auth_sessions,omitempty"` // UserAllowedGroups holds the value of the user_allowed_groups edge. UserAllowedGroups []*UserAllowedGroup `json:"user_allowed_groups,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [11]bool + loadedTypes [13]bool } // APIKeysOrErr returns the APIKeys value or an error if the edge @@ -180,10 +190,28 @@ func (e UserEdges) PaymentOrdersOrErr() ([]*PaymentOrder, error) { return nil, &NotLoadedError{edge: "payment_orders"} } +// AuthIdentitiesOrErr returns the AuthIdentities value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) AuthIdentitiesOrErr() ([]*AuthIdentity, error) { + if e.loadedTypes[10] { + return e.AuthIdentities, nil + } + return nil, &NotLoadedError{edge: "auth_identities"} +} + +// PendingAuthSessionsOrErr returns the PendingAuthSessions value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) PendingAuthSessionsOrErr() ([]*PendingAuthSession, error) { + if e.loadedTypes[11] { + return e.PendingAuthSessions, nil + } + return nil, &NotLoadedError{edge: "pending_auth_sessions"} +} + // UserAllowedGroupsOrErr returns the UserAllowedGroups value or an error if the edge // was not loaded in eager-loading. func (e UserEdges) UserAllowedGroupsOrErr() ([]*UserAllowedGroup, error) { - if e.loadedTypes[10] { + if e.loadedTypes[12] { return e.UserAllowedGroups, nil } return nil, &NotLoadedError{edge: "user_allowed_groups"} @@ -200,9 +228,9 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt, user.FieldLastLoginAt, user.FieldLastActiveAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -312,6 +340,26 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSignupSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field signup_source", values[i]) + } else if value.Valid { + _m.SignupSource = value.String + } + case user.FieldLastLoginAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login_at", values[i]) + } else if value.Valid { + _m.LastLoginAt = new(time.Time) + *_m.LastLoginAt = value.Time + } + case user.FieldLastActiveAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_active_at", values[i]) + } else if value.Valid { + _m.LastActiveAt = new(time.Time) + *_m.LastActiveAt = value.Time + } case user.FieldBalanceNotifyEnabled: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) @@ -406,6 +454,16 @@ func (_m *User) QueryPaymentOrders() *PaymentOrderQuery { return NewUserClient(_m.config).QueryPaymentOrders(_m) } +// QueryAuthIdentities queries the "auth_identities" edge of the User entity. +func (_m *User) QueryAuthIdentities() *AuthIdentityQuery { + return NewUserClient(_m.config).QueryAuthIdentities(_m) +} + +// QueryPendingAuthSessions queries the "pending_auth_sessions" edge of the User entity. +func (_m *User) QueryPendingAuthSessions() *PendingAuthSessionQuery { + return NewUserClient(_m.config).QueryPendingAuthSessions(_m) +} + // QueryUserAllowedGroups queries the "user_allowed_groups" edge of the User entity. func (_m *User) QueryUserAllowedGroups() *UserAllowedGroupQuery { return NewUserClient(_m.config).QueryUserAllowedGroups(_m) @@ -482,6 +540,19 @@ func (_m *User) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + builder.WriteString("signup_source=") + builder.WriteString(_m.SignupSource) + builder.WriteString(", ") + if v := _m.LastLoginAt; v != nil { + builder.WriteString("last_login_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.LastActiveAt; v != nil { + builder.WriteString("last_active_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("balance_notify_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) builder.WriteString(", ") diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index d88a3a38..567e3b14 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,12 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSignupSource holds the string denoting the signup_source field in the database. + FieldSignupSource = "signup_source" + // FieldLastLoginAt holds the string denoting the last_login_at field in the database. + FieldLastLoginAt = "last_login_at" + // FieldLastActiveAt holds the string denoting the last_active_at field in the database. + FieldLastActiveAt = "last_active_at" // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. FieldBalanceNotifyEnabled = "balance_notify_enabled" // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. @@ -73,6 +79,10 @@ const ( EdgePromoCodeUsages = "promo_code_usages" // EdgePaymentOrders holds the string denoting the payment_orders edge name in mutations. EdgePaymentOrders = "payment_orders" + // EdgeAuthIdentities holds the string denoting the auth_identities edge name in mutations. + EdgeAuthIdentities = "auth_identities" + // EdgePendingAuthSessions holds the string denoting the pending_auth_sessions edge name in mutations. + EdgePendingAuthSessions = "pending_auth_sessions" // EdgeUserAllowedGroups holds the string denoting the user_allowed_groups edge name in mutations. EdgeUserAllowedGroups = "user_allowed_groups" // Table holds the table name of the user in the database. @@ -145,6 +155,20 @@ const ( PaymentOrdersInverseTable = "payment_orders" // PaymentOrdersColumn is the table column denoting the payment_orders relation/edge. PaymentOrdersColumn = "user_id" + // AuthIdentitiesTable is the table that holds the auth_identities relation/edge. + AuthIdentitiesTable = "auth_identities" + // AuthIdentitiesInverseTable is the table name for the AuthIdentity entity. + // It exists in this package in order to avoid circular dependency with the "authidentity" package. + AuthIdentitiesInverseTable = "auth_identities" + // AuthIdentitiesColumn is the table column denoting the auth_identities relation/edge. + AuthIdentitiesColumn = "user_id" + // PendingAuthSessionsTable is the table that holds the pending_auth_sessions relation/edge. + PendingAuthSessionsTable = "pending_auth_sessions" + // PendingAuthSessionsInverseTable is the table name for the PendingAuthSession entity. + // It exists in this package in order to avoid circular dependency with the "pendingauthsession" package. + PendingAuthSessionsInverseTable = "pending_auth_sessions" + // PendingAuthSessionsColumn is the table column denoting the pending_auth_sessions relation/edge. + PendingAuthSessionsColumn = "target_user_id" // UserAllowedGroupsTable is the table that holds the user_allowed_groups relation/edge. UserAllowedGroupsTable = "user_allowed_groups" // UserAllowedGroupsInverseTable is the table name for the UserAllowedGroup entity. @@ -171,6 +195,9 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSignupSource, + FieldLastLoginAt, + FieldLastActiveAt, FieldBalanceNotifyEnabled, FieldBalanceNotifyThresholdType, FieldBalanceNotifyThreshold, @@ -232,6 +259,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSignupSource holds the default value on creation for the "signup_source" field. + DefaultSignupSource string + // SignupSourceValidator is a validator for the "signup_source" field. It is called by the builders before save. + SignupSourceValidator func(string) error // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. DefaultBalanceNotifyEnabled bool // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. @@ -320,6 +351,21 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySignupSource orders the results by the signup_source field. +func BySignupSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignupSource, opts...).ToFunc() +} + +// ByLastLoginAt orders the results by the last_login_at field. +func ByLastLoginAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLoginAt, opts...).ToFunc() +} + +// ByLastActiveAt orders the results by the last_active_at field. +func ByLastActiveAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActiveAt, opts...).ToFunc() +} + // ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() @@ -485,6 +531,34 @@ func ByPaymentOrders(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { } } +// ByAuthIdentitiesCount orders the results by auth_identities count. +func ByAuthIdentitiesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAuthIdentitiesStep(), opts...) + } +} + +// ByAuthIdentities orders the results by auth_identities terms. +func ByAuthIdentities(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAuthIdentitiesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPendingAuthSessionsCount orders the results by pending_auth_sessions count. +func ByPendingAuthSessionsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPendingAuthSessionsStep(), opts...) + } +} + +// ByPendingAuthSessions orders the results by pending_auth_sessions terms. +func ByPendingAuthSessions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPendingAuthSessionsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + // ByUserAllowedGroupsCount orders the results by user_allowed_groups count. func ByUserAllowedGroupsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -568,6 +642,20 @@ func newPaymentOrdersStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PaymentOrdersTable, PaymentOrdersColumn), ) } +func newAuthIdentitiesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AuthIdentitiesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) +} +func newPendingAuthSessionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PendingAuthSessionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) +} func newUserAllowedGroupsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 2788aa7a..cbcfcc26 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,21 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SignupSource applies equality check predicate on the "signup_source" field. It's identical to SignupSourceEQ. +func SignupSource(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// LastLoginAt applies equality check predicate on the "last_login_at" field. It's identical to LastLoginAtEQ. +func LastLoginAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastActiveAt applies equality check predicate on the "last_active_at" field. It's identical to LastActiveAtEQ. +func LastActiveAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + // BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. func BalanceNotifyEnabled(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -885,6 +900,171 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SignupSourceEQ applies the EQ predicate on the "signup_source" field. +func SignupSourceEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldSignupSource, v)) +} + +// SignupSourceNEQ applies the NEQ predicate on the "signup_source" field. +func SignupSourceNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSignupSource, v)) +} + +// SignupSourceIn applies the In predicate on the "signup_source" field. +func SignupSourceIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldSignupSource, vs...)) +} + +// SignupSourceNotIn applies the NotIn predicate on the "signup_source" field. +func SignupSourceNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSignupSource, vs...)) +} + +// SignupSourceGT applies the GT predicate on the "signup_source" field. +func SignupSourceGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldSignupSource, v)) +} + +// SignupSourceGTE applies the GTE predicate on the "signup_source" field. +func SignupSourceGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldSignupSource, v)) +} + +// SignupSourceLT applies the LT predicate on the "signup_source" field. +func SignupSourceLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldSignupSource, v)) +} + +// SignupSourceLTE applies the LTE predicate on the "signup_source" field. +func SignupSourceLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldSignupSource, v)) +} + +// SignupSourceContains applies the Contains predicate on the "signup_source" field. +func SignupSourceContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldSignupSource, v)) +} + +// SignupSourceHasPrefix applies the HasPrefix predicate on the "signup_source" field. +func SignupSourceHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldSignupSource, v)) +} + +// SignupSourceHasSuffix applies the HasSuffix predicate on the "signup_source" field. +func SignupSourceHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldSignupSource, v)) +} + +// SignupSourceEqualFold applies the EqualFold predicate on the "signup_source" field. +func SignupSourceEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldSignupSource, v)) +} + +// SignupSourceContainsFold applies the ContainsFold predicate on the "signup_source" field. +func SignupSourceContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldSignupSource, v)) +} + +// LastLoginAtEQ applies the EQ predicate on the "last_login_at" field. +func LastLoginAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtNEQ applies the NEQ predicate on the "last_login_at" field. +func LastLoginAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastLoginAt, v)) +} + +// LastLoginAtIn applies the In predicate on the "last_login_at" field. +func LastLoginAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtNotIn applies the NotIn predicate on the "last_login_at" field. +func LastLoginAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastLoginAt, vs...)) +} + +// LastLoginAtGT applies the GT predicate on the "last_login_at" field. +func LastLoginAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastLoginAt, v)) +} + +// LastLoginAtGTE applies the GTE predicate on the "last_login_at" field. +func LastLoginAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastLoginAt, v)) +} + +// LastLoginAtLT applies the LT predicate on the "last_login_at" field. +func LastLoginAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastLoginAt, v)) +} + +// LastLoginAtLTE applies the LTE predicate on the "last_login_at" field. +func LastLoginAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastLoginAt, v)) +} + +// LastLoginAtIsNil applies the IsNil predicate on the "last_login_at" field. +func LastLoginAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastLoginAt)) +} + +// LastLoginAtNotNil applies the NotNil predicate on the "last_login_at" field. +func LastLoginAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastLoginAt)) +} + +// LastActiveAtEQ applies the EQ predicate on the "last_active_at" field. +func LastActiveAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtNEQ applies the NEQ predicate on the "last_active_at" field. +func LastActiveAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldLastActiveAt, v)) +} + +// LastActiveAtIn applies the In predicate on the "last_active_at" field. +func LastActiveAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtNotIn applies the NotIn predicate on the "last_active_at" field. +func LastActiveAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldLastActiveAt, vs...)) +} + +// LastActiveAtGT applies the GT predicate on the "last_active_at" field. +func LastActiveAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldLastActiveAt, v)) +} + +// LastActiveAtGTE applies the GTE predicate on the "last_active_at" field. +func LastActiveAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldLastActiveAt, v)) +} + +// LastActiveAtLT applies the LT predicate on the "last_active_at" field. +func LastActiveAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldLastActiveAt, v)) +} + +// LastActiveAtLTE applies the LTE predicate on the "last_active_at" field. +func LastActiveAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldLastActiveAt, v)) +} + +// LastActiveAtIsNil applies the IsNil predicate on the "last_active_at" field. +func LastActiveAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldLastActiveAt)) +} + +// LastActiveAtNotNil applies the NotNil predicate on the "last_active_at" field. +func LastActiveAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldLastActiveAt)) +} + // BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. func BalanceNotifyEnabledEQ(v bool) predicate.User { return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) @@ -1345,6 +1525,52 @@ func HasPaymentOrdersWith(preds ...predicate.PaymentOrder) predicate.User { }) } +// HasAuthIdentities applies the HasEdge predicate on the "auth_identities" edge. +func HasAuthIdentities() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AuthIdentitiesTable, AuthIdentitiesColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasAuthIdentitiesWith applies the HasEdge predicate on the "auth_identities" edge with a given conditions (other predicates). +func HasAuthIdentitiesWith(preds ...predicate.AuthIdentity) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newAuthIdentitiesStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasPendingAuthSessions applies the HasEdge predicate on the "pending_auth_sessions" edge. +func HasPendingAuthSessions() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PendingAuthSessionsTable, PendingAuthSessionsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasPendingAuthSessionsWith applies the HasEdge predicate on the "pending_auth_sessions" edge with a given conditions (other predicates). +func HasPendingAuthSessionsWith(preds ...predicate.PendingAuthSession) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newPendingAuthSessionsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // HasUserAllowedGroups applies the HasEdge predicate on the "user_allowed_groups" edge. func HasUserAllowedGroups() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index fbc64f9c..db95e813 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -211,6 +213,48 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSignupSource sets the "signup_source" field. +func (_c *UserCreate) SetSignupSource(v string) *UserCreate { + _c.mutation.SetSignupSource(v) + return _c +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_c *UserCreate) SetNillableSignupSource(v *string) *UserCreate { + if v != nil { + _c.SetSignupSource(*v) + } + return _c +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_c *UserCreate) SetLastLoginAt(v time.Time) *UserCreate { + _c.mutation.SetLastLoginAt(v) + return _c +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastLoginAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastLoginAt(*v) + } + return _c +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_c *UserCreate) SetLastActiveAt(v time.Time) *UserCreate { + _c.mutation.SetLastActiveAt(v) + return _c +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableLastActiveAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetLastActiveAt(*v) + } + return _c +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { _c.mutation.SetBalanceNotifyEnabled(v) @@ -431,6 +475,36 @@ func (_c *UserCreate) AddPaymentOrders(v ...*PaymentOrder) *UserCreate { return _c.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_c *UserCreate) AddAuthIdentityIDs(ids ...int64) *UserCreate { + _c.mutation.AddAuthIdentityIDs(ids...) + return _c +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_c *UserCreate) AddAuthIdentities(v ...*AuthIdentity) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_c *UserCreate) AddPendingAuthSessionIDs(ids ...int64) *UserCreate { + _c.mutation.AddPendingAuthSessionIDs(ids...) + return _c +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_c *UserCreate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserCreate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _c.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_c *UserCreate) Mutation() *UserMutation { return _c.mutation @@ -510,6 +584,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SignupSource(); !ok { + v := user.DefaultSignupSource + _c.mutation.SetSignupSource(v) + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { v := user.DefaultBalanceNotifyEnabled _c.mutation.SetBalanceNotifyEnabled(v) @@ -589,6 +667,14 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SignupSource(); !ok { + return &ValidationError{Name: "signup_source", err: errors.New(`ent: missing required field "User.signup_source"`)} + } + if v, ok := _c.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} } @@ -684,6 +770,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + _node.SignupSource = value + } + if value, ok := _c.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + _node.LastLoginAt = &value + } + if value, ok := _c.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + _node.LastActiveAt = &value + } if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) _node.BalanceNotifyEnabled = value @@ -868,6 +966,38 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := _c.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := _c.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -1106,6 +1236,54 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsert) SetSignupSource(v string) *UserUpsert { + u.Set(user.FieldSignupSource, v) + return u +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsert) UpdateSignupSource() *UserUpsert { + u.SetExcluded(user.FieldSignupSource) + return u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsert) SetLastLoginAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastLoginAt, v) + return u +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastLoginAt() *UserUpsert { + u.SetExcluded(user.FieldLastLoginAt) + return u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsert) ClearLastLoginAt() *UserUpsert { + u.SetNull(user.FieldLastLoginAt) + return u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsert) SetLastActiveAt(v time.Time) *UserUpsert { + u.Set(user.FieldLastActiveAt, v) + return u +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateLastActiveAt() *UserUpsert { + u.SetExcluded(user.FieldLastActiveAt) + return u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsert) ClearLastActiveAt() *UserUpsert { + u.SetNull(user.FieldLastActiveAt) + return u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { u.Set(user.FieldBalanceNotifyEnabled, v) @@ -1446,6 +1624,62 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertOne) SetSignupSource(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSignupSource() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertOne) SetLastLoginAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertOne) ClearLastLoginAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertOne) SetLastActiveAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertOne) ClearLastActiveAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { return u.Update(func(s *UserUpsert) { @@ -1965,6 +2199,62 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSignupSource sets the "signup_source" field. +func (u *UserUpsertBulk) SetSignupSource(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSignupSource(v) + }) +} + +// UpdateSignupSource sets the "signup_source" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSignupSource() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSignupSource() + }) +} + +// SetLastLoginAt sets the "last_login_at" field. +func (u *UserUpsertBulk) SetLastLoginAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastLoginAt(v) + }) +} + +// UpdateLastLoginAt sets the "last_login_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastLoginAt() + }) +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (u *UserUpsertBulk) ClearLastLoginAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastLoginAt() + }) +} + +// SetLastActiveAt sets the "last_active_at" field. +func (u *UserUpsertBulk) SetLastActiveAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetLastActiveAt(v) + }) +} + +// UpdateLastActiveAt sets the "last_active_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateLastActiveAt() + }) +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (u *UserUpsertBulk) ClearLastActiveAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearLastActiveAt() + }) +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { return u.Update(func(s *UserUpsert) { diff --git a/backend/ent/user_query.go b/backend/ent/user_query.go index 113d87ac..f1ee5cfe 100644 --- a/backend/ent/user_query.go +++ b/backend/ent/user_query.go @@ -15,8 +15,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -44,6 +46,8 @@ type UserQuery struct { withAttributeValues *UserAttributeValueQuery withPromoCodeUsages *PromoCodeUsageQuery withPaymentOrders *PaymentOrderQuery + withAuthIdentities *AuthIdentityQuery + withPendingAuthSessions *PendingAuthSessionQuery withUserAllowedGroups *UserAllowedGroupQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). @@ -302,6 +306,50 @@ func (_q *UserQuery) QueryPaymentOrders() *PaymentOrderQuery { return query } +// QueryAuthIdentities chains the current query on the "auth_identities" edge. +func (_q *UserQuery) QueryAuthIdentities() *AuthIdentityQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(authidentity.Table, authidentity.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.AuthIdentitiesTable, user.AuthIdentitiesColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryPendingAuthSessions chains the current query on the "pending_auth_sessions" edge. +func (_q *UserQuery) QueryPendingAuthSessions() *PendingAuthSessionQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + selector := _q.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(pendingauthsession.Table, pendingauthsession.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.PendingAuthSessionsTable, user.PendingAuthSessionsColumn), + ) + fromU = sqlgraph.SetNeighbors(_q.driver.Dialect(), step) + return fromU, nil + } + return query +} + // QueryUserAllowedGroups chains the current query on the "user_allowed_groups" edge. func (_q *UserQuery) QueryUserAllowedGroups() *UserAllowedGroupQuery { query := (&UserAllowedGroupClient{config: _q.config}).Query() @@ -526,6 +574,8 @@ func (_q *UserQuery) Clone() *UserQuery { withAttributeValues: _q.withAttributeValues.Clone(), withPromoCodeUsages: _q.withPromoCodeUsages.Clone(), withPaymentOrders: _q.withPaymentOrders.Clone(), + withAuthIdentities: _q.withAuthIdentities.Clone(), + withPendingAuthSessions: _q.withPendingAuthSessions.Clone(), withUserAllowedGroups: _q.withUserAllowedGroups.Clone(), // clone intermediate query. sql: _q.sql.Clone(), @@ -643,6 +693,28 @@ func (_q *UserQuery) WithPaymentOrders(opts ...func(*PaymentOrderQuery)) *UserQu return _q } +// WithAuthIdentities tells the query-builder to eager-load the nodes that are connected to +// the "auth_identities" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithAuthIdentities(opts ...func(*AuthIdentityQuery)) *UserQuery { + query := (&AuthIdentityClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withAuthIdentities = query + return _q +} + +// WithPendingAuthSessions tells the query-builder to eager-load the nodes that are connected to +// the "pending_auth_sessions" edge. The optional arguments are used to configure the query builder of the edge. +func (_q *UserQuery) WithPendingAuthSessions(opts ...func(*PendingAuthSessionQuery)) *UserQuery { + query := (&PendingAuthSessionClient{config: _q.config}).Query() + for _, opt := range opts { + opt(query) + } + _q.withPendingAuthSessions = query + return _q +} + // WithUserAllowedGroups tells the query-builder to eager-load the nodes that are connected to // the "user_allowed_groups" edge. The optional arguments are used to configure the query builder of the edge. func (_q *UserQuery) WithUserAllowedGroups(opts ...func(*UserAllowedGroupQuery)) *UserQuery { @@ -732,7 +804,7 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = _q.querySpec() - loadedTypes = [11]bool{ + loadedTypes = [13]bool{ _q.withAPIKeys != nil, _q.withRedeemCodes != nil, _q.withSubscriptions != nil, @@ -743,6 +815,8 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e _q.withAttributeValues != nil, _q.withPromoCodeUsages != nil, _q.withPaymentOrders != nil, + _q.withAuthIdentities != nil, + _q.withPendingAuthSessions != nil, _q.withUserAllowedGroups != nil, } ) @@ -839,6 +913,22 @@ func (_q *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := _q.withAuthIdentities; query != nil { + if err := _q.loadAuthIdentities(ctx, query, nodes, + func(n *User) { n.Edges.AuthIdentities = []*AuthIdentity{} }, + func(n *User, e *AuthIdentity) { n.Edges.AuthIdentities = append(n.Edges.AuthIdentities, e) }); err != nil { + return nil, err + } + } + if query := _q.withPendingAuthSessions; query != nil { + if err := _q.loadPendingAuthSessions(ctx, query, nodes, + func(n *User) { n.Edges.PendingAuthSessions = []*PendingAuthSession{} }, + func(n *User, e *PendingAuthSession) { + n.Edges.PendingAuthSessions = append(n.Edges.PendingAuthSessions, e) + }); err != nil { + return nil, err + } + } if query := _q.withUserAllowedGroups; query != nil { if err := _q.loadUserAllowedGroups(ctx, query, nodes, func(n *User) { n.Edges.UserAllowedGroups = []*UserAllowedGroup{} }, @@ -1186,6 +1276,69 @@ func (_q *UserQuery) loadPaymentOrders(ctx context.Context, query *PaymentOrderQ } return nil } +func (_q *UserQuery) loadAuthIdentities(ctx context.Context, query *AuthIdentityQuery, nodes []*User, init func(*User), assign func(*User, *AuthIdentity)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(authidentity.FieldUserID) + } + query.Where(predicate.AuthIdentity(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.AuthIdentitiesColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} +func (_q *UserQuery) loadPendingAuthSessions(ctx context.Context, query *PendingAuthSessionQuery, nodes []*User, init func(*User), assign func(*User, *PendingAuthSession)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[int64]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(pendingauthsession.FieldTargetUserID) + } + query.Where(predicate.PendingAuthSession(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.PendingAuthSessionsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.TargetUserID + if fk == nil { + return fmt.Errorf(`foreign-key "target_user_id" is nil for node %v`, n.ID) + } + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "target_user_id" returned %v for node %v`, *fk, n.ID) + } + assign(node, n) + } + return nil +} func (_q *UserQuery) loadUserAllowedGroups(ctx context.Context, query *UserAllowedGroupQuery, nodes []*User, init func(*User), assign func(*User, *UserAllowedGroup)) error { fks := make([]driver.Value, 0, len(nodes)) nodeids := make(map[int64]*User) diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 6b355247..677eeb6b 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -13,8 +13,10 @@ import ( "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/redeemcode" @@ -243,6 +245,60 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdate) SetSignupSource(v string) *UserUpdate { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSignupSource(v *string) *UserUpdate { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdate) SetLastLoginAt(v time.Time) *UserUpdate { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastLoginAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdate) ClearLastLoginAt() *UserUpdate { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdate) SetLastActiveAt(v time.Time) *UserUpdate { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableLastActiveAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdate) ClearLastActiveAt() *UserUpdate { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { _u.mutation.SetBalanceNotifyEnabled(v) @@ -483,6 +539,36 @@ func (_u *UserUpdate) AddPaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdate) AddAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) AddAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdate) AddPendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdate) Mutation() *UserMutation { return _u.mutation @@ -698,6 +784,48 @@ func (_u *UserUpdate) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdate { return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdate) ClearAuthIdentities() *UserUpdate { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdate) RemoveAuthIdentityIDs(ids ...int64) *UserUpdate { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdate) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdate) ClearPendingAuthSessions() *UserUpdate { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdate) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdate { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdate) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdate { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (_u *UserUpdate) Save(ctx context.Context) (int, error) { if err := _u.defaults(); err != nil { @@ -767,6 +895,11 @@ func (_u *UserUpdate) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -836,6 +969,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -1322,6 +1470,96 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -1548,6 +1786,60 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSignupSource sets the "signup_source" field. +func (_u *UserUpdateOne) SetSignupSource(v string) *UserUpdateOne { + _u.mutation.SetSignupSource(v) + return _u +} + +// SetNillableSignupSource sets the "signup_source" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSignupSource(v *string) *UserUpdateOne { + if v != nil { + _u.SetSignupSource(*v) + } + return _u +} + +// SetLastLoginAt sets the "last_login_at" field. +func (_u *UserUpdateOne) SetLastLoginAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastLoginAt(v) + return _u +} + +// SetNillableLastLoginAt sets the "last_login_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastLoginAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastLoginAt(*v) + } + return _u +} + +// ClearLastLoginAt clears the value of the "last_login_at" field. +func (_u *UserUpdateOne) ClearLastLoginAt() *UserUpdateOne { + _u.mutation.ClearLastLoginAt() + return _u +} + +// SetLastActiveAt sets the "last_active_at" field. +func (_u *UserUpdateOne) SetLastActiveAt(v time.Time) *UserUpdateOne { + _u.mutation.SetLastActiveAt(v) + return _u +} + +// SetNillableLastActiveAt sets the "last_active_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableLastActiveAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetLastActiveAt(*v) + } + return _u +} + +// ClearLastActiveAt clears the value of the "last_active_at" field. +func (_u *UserUpdateOne) ClearLastActiveAt() *UserUpdateOne { + _u.mutation.ClearLastActiveAt() + return _u +} + // SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { _u.mutation.SetBalanceNotifyEnabled(v) @@ -1788,6 +2080,36 @@ func (_u *UserUpdateOne) AddPaymentOrders(v ...*PaymentOrder) *UserUpdateOne { return _u.AddPaymentOrderIDs(ids...) } +// AddAuthIdentityIDs adds the "auth_identities" edge to the AuthIdentity entity by IDs. +func (_u *UserUpdateOne) AddAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddAuthIdentityIDs(ids...) + return _u +} + +// AddAuthIdentities adds the "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) AddAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddAuthIdentityIDs(ids...) +} + +// AddPendingAuthSessionIDs adds the "pending_auth_sessions" edge to the PendingAuthSession entity by IDs. +func (_u *UserUpdateOne) AddPendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.AddPendingAuthSessionIDs(ids...) + return _u +} + +// AddPendingAuthSessions adds the "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) AddPendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.AddPendingAuthSessionIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (_u *UserUpdateOne) Mutation() *UserMutation { return _u.mutation @@ -2003,6 +2325,48 @@ func (_u *UserUpdateOne) RemovePaymentOrders(v ...*PaymentOrder) *UserUpdateOne return _u.RemovePaymentOrderIDs(ids...) } +// ClearAuthIdentities clears all "auth_identities" edges to the AuthIdentity entity. +func (_u *UserUpdateOne) ClearAuthIdentities() *UserUpdateOne { + _u.mutation.ClearAuthIdentities() + return _u +} + +// RemoveAuthIdentityIDs removes the "auth_identities" edge to AuthIdentity entities by IDs. +func (_u *UserUpdateOne) RemoveAuthIdentityIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemoveAuthIdentityIDs(ids...) + return _u +} + +// RemoveAuthIdentities removes "auth_identities" edges to AuthIdentity entities. +func (_u *UserUpdateOne) RemoveAuthIdentities(v ...*AuthIdentity) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemoveAuthIdentityIDs(ids...) +} + +// ClearPendingAuthSessions clears all "pending_auth_sessions" edges to the PendingAuthSession entity. +func (_u *UserUpdateOne) ClearPendingAuthSessions() *UserUpdateOne { + _u.mutation.ClearPendingAuthSessions() + return _u +} + +// RemovePendingAuthSessionIDs removes the "pending_auth_sessions" edge to PendingAuthSession entities by IDs. +func (_u *UserUpdateOne) RemovePendingAuthSessionIDs(ids ...int64) *UserUpdateOne { + _u.mutation.RemovePendingAuthSessionIDs(ids...) + return _u +} + +// RemovePendingAuthSessions removes "pending_auth_sessions" edges to PendingAuthSession entities. +func (_u *UserUpdateOne) RemovePendingAuthSessions(v ...*PendingAuthSession) *UserUpdateOne { + ids := make([]int64, len(v)) + for i := range v { + ids[i] = v[i].ID + } + return _u.RemovePendingAuthSessionIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (_u *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { _u.mutation.Where(ps...) @@ -2085,6 +2449,11 @@ func (_u *UserUpdateOne) check() error { return &ValidationError{Name: "username", err: fmt.Errorf(`ent: validator failed for field "User.username": %w`, err)} } } + if v, ok := _u.mutation.SignupSource(); ok { + if err := user.SignupSourceValidator(v); err != nil { + return &ValidationError{Name: "signup_source", err: fmt.Errorf(`ent: validator failed for field "User.signup_source": %w`, err)} + } + } return nil } @@ -2171,6 +2540,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SignupSource(); ok { + _spec.SetField(user.FieldSignupSource, field.TypeString, value) + } + if value, ok := _u.mutation.LastLoginAt(); ok { + _spec.SetField(user.FieldLastLoginAt, field.TypeTime, value) + } + if _u.mutation.LastLoginAtCleared() { + _spec.ClearField(user.FieldLastLoginAt, field.TypeTime) + } + if value, ok := _u.mutation.LastActiveAt(); ok { + _spec.SetField(user.FieldLastActiveAt, field.TypeTime, value) + } + if _u.mutation.LastActiveAtCleared() { + _spec.ClearField(user.FieldLastActiveAt, field.TypeTime) + } if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) } @@ -2657,6 +3041,96 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if _u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedAuthIdentitiesIDs(); len(nodes) > 0 && !_u.mutation.AuthIdentitiesCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.AuthIdentitiesIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.AuthIdentitiesTable, + Columns: []string{user.AuthIdentitiesColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(authidentity.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if _u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.RemovedPendingAuthSessionsIDs(); len(nodes) > 0 && !_u.mutation.PendingAuthSessionsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := _u.mutation.PendingAuthSessionsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.PendingAuthSessionsTable, + Columns: []string{user.PendingAuthSessionsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(pendingauthsession.FieldID, field.TypeInt64), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } _node = &User{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index dd9a4e58..6136e9ea 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1608,6 +1608,9 @@ func (c *Config) Validate() error { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } if c.LinuxDo.Enabled { + if !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.enabled=true") + } if strings.TrimSpace(c.LinuxDo.ClientID) == "" { return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") } @@ -1629,9 +1632,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") @@ -1663,6 +1663,12 @@ func (c *Config) Validate() error { warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) } if c.OIDC.Enabled { + if !c.OIDC.UsePKCE { + return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.enabled=true") + } + if !c.OIDC.ValidateIDToken { + return fmt.Errorf("oidc_connect.validate_id_token must be true when oidc_connect.enabled=true") + } if strings.TrimSpace(c.OIDC.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") } @@ -1685,9 +1691,6 @@ func (c *Config) Validate() error { default: return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.OIDC.ClientSecret) == "" { return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index bec0f126..fe5c7928 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -73,6 +73,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) @@ -93,7 +98,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { paymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, @@ -200,7 +205,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } // UpdateSettingsRequest 更新设置请求 @@ -276,9 +282,30 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -357,6 +384,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + previousAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } // 验证参数 if req.DefaultConcurrency < 1 { @@ -381,6 +413,10 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.SMTPPort = 587 } req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) + req.AuthSourceDefaultEmailSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultEmailSubscriptions) + req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) + req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) + req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -538,25 +574,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "OIDC scopes must contain openid") return } + if !req.OIDCConnectUsePKCE { + response.BadRequest(c, "OIDC PKCE must be enabled") + return + } + if !req.OIDCConnectValidateIDToken { + response.BadRequest(c, "OIDC ID Token validation must be enabled") + return + } switch req.OIDCConnectTokenAuthMethod { case "", "client_secret_post", "client_secret_basic", "none": default: response.BadRequest(c, "OIDC Token Auth Method must be one of client_secret_post/client_secret_basic/none") return } - if req.OIDCConnectTokenAuthMethod == "none" && !req.OIDCConnectUsePKCE { - response.BadRequest(c, "OIDC PKCE must be enabled when token_auth_method=none") - return - } if req.OIDCConnectClockSkewSeconds < 0 || req.OIDCConnectClockSkewSeconds > 600 { response.BadRequest(c, "OIDC clock skew seconds must be between 0 and 600") return } - if req.OIDCConnectValidateIDToken { - if req.OIDCConnectAllowedSigningAlgs == "" { - response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") - return - } + if req.OIDCConnectAllowedSigningAlgs == "" { + response.BadRequest(c, "OIDC Allowed Signing Algs is required when validate_id_token=true") + return } if req.OIDCConnectJWKSURL != "" { if err := config.ValidateAbsoluteHTTPURL(req.OIDCConnectJWKSURL); err != nil { @@ -933,6 +971,41 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + authSourceDefaults := &service.AuthSourceDefaultSettings{ + Email: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultEmailBalance, previousAuthSourceDefaults.Email.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultEmailConcurrency, previousAuthSourceDefaults.Email.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultEmailSubscriptions, previousAuthSourceDefaults.Email.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnSignup, previousAuthSourceDefaults.Email.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultEmailGrantOnFirstBind, previousAuthSourceDefaults.Email.GrantOnFirstBind), + }, + LinuxDo: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultLinuxDoBalance, previousAuthSourceDefaults.LinuxDo.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultLinuxDoConcurrency, previousAuthSourceDefaults.LinuxDo.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultLinuxDoSubscriptions, previousAuthSourceDefaults.LinuxDo.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnSignup, previousAuthSourceDefaults.LinuxDo.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultLinuxDoGrantOnFirstBind, previousAuthSourceDefaults.LinuxDo.GrantOnFirstBind), + }, + OIDC: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultOIDCBalance, previousAuthSourceDefaults.OIDC.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultOIDCConcurrency, previousAuthSourceDefaults.OIDC.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultOIDCSubscriptions, previousAuthSourceDefaults.OIDC.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnSignup, previousAuthSourceDefaults.OIDC.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultOIDCGrantOnFirstBind, previousAuthSourceDefaults.OIDC.GrantOnFirstBind), + }, + WeChat: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultWeChatBalance, previousAuthSourceDefaults.WeChat.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultWeChatConcurrency, previousAuthSourceDefaults.WeChat.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultWeChatSubscriptions, previousAuthSourceDefaults.WeChat.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnSignup, previousAuthSourceDefaults.WeChat.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultWeChatGrantOnFirstBind, previousAuthSourceDefaults.WeChat.GrantOnFirstBind), + }, + ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), + } + if err := h.settingService.UpdateAuthSourceDefaultSettings(c.Request.Context(), authSourceDefaults); err != nil { + response.ErrorFrom(c, err) + return + } // Update payment configuration (integrated into system settings). // Skip if no payment fields were provided (prevents accidental wipe). @@ -977,6 +1050,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) for _, sub := range updatedSettings.DefaultSubscriptions { updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ @@ -994,7 +1072,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { updatedPaymentCfg = &service.PaymentConfig{} } - response.Success(c, dto.SystemSettings{ + payload := dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, RegistrationEmailSuffixWhitelist: updatedSettings.RegistrationEmailSuffixWhitelist, @@ -1100,7 +1178,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - }) + } + response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } // hasPaymentFields returns true if any payment-related field was explicitly provided. @@ -1412,6 +1491,84 @@ func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto return normalized } +func normalizeOptionalDefaultSubscriptions(input *[]dto.DefaultSubscriptionSetting) *[]dto.DefaultSubscriptionSetting { + if input == nil { + return nil + } + normalized := normalizeDefaultSubscriptions(*input) + return &normalized +} + +func float64ValueOrDefault(value *float64, fallback float64) float64 { + if value == nil { + return fallback + } + return *value +} + +func intValueOrDefault(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +func boolValueOrDefault(value *bool, fallback bool) bool { + if value == nil { + return fallback + } + return *value +} + +func defaultSubscriptionsValueOrDefault(input *[]dto.DefaultSubscriptionSetting, fallback []service.DefaultSubscriptionSetting) []service.DefaultSubscriptionSetting { + if input == nil { + return fallback + } + result := make([]service.DefaultSubscriptionSetting, 0, len(*input)) + for _, item := range *input { + result = append(result, service.DefaultSubscriptionSetting{ + GroupID: item.GroupID, + ValidityDays: item.ValidityDays, + }) + } + return result +} + +func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults *service.AuthSourceDefaultSettings) map[string]any { + data := make(map[string]any) + raw, err := json.Marshal(settings) + if err == nil { + _ = json.Unmarshal(raw, &data) + } + if authSourceDefaults == nil { + authSourceDefaults = &service.AuthSourceDefaultSettings{} + } + + data["auth_source_default_email_balance"] = authSourceDefaults.Email.Balance + data["auth_source_default_email_concurrency"] = authSourceDefaults.Email.Concurrency + data["auth_source_default_email_subscriptions"] = authSourceDefaults.Email.Subscriptions + data["auth_source_default_email_grant_on_signup"] = authSourceDefaults.Email.GrantOnSignup + data["auth_source_default_email_grant_on_first_bind"] = authSourceDefaults.Email.GrantOnFirstBind + data["auth_source_default_linuxdo_balance"] = authSourceDefaults.LinuxDo.Balance + data["auth_source_default_linuxdo_concurrency"] = authSourceDefaults.LinuxDo.Concurrency + data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions + data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup + data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance + data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency + data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions + data["auth_source_default_oidc_grant_on_signup"] = authSourceDefaults.OIDC.GrantOnSignup + data["auth_source_default_oidc_grant_on_first_bind"] = authSourceDefaults.OIDC.GrantOnFirstBind + data["auth_source_default_wechat_balance"] = authSourceDefaults.WeChat.Balance + data["auth_source_default_wechat_concurrency"] = authSourceDefaults.WeChat.Concurrency + data["auth_source_default_wechat_subscriptions"] = authSourceDefaults.WeChat.Subscriptions + data["auth_source_default_wechat_grant_on_signup"] = authSourceDefaults.WeChat.GrantOnSignup + data["auth_source_default_wechat_grant_on_first_bind"] = authSourceDefaults.WeChat.GrantOnFirstBind + data["force_email_on_third_party_signup"] = authSourceDefaults.ForceEmailOnThirdPartySignup + + return data +} + func equalStringSlice(a, b []string) bool { if len(a) != len(b) { return false diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go new file mode 100644 index 00000000..b26fa447 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -0,0 +1,149 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerRepoStub struct { + values map[string]string + lastUpdates map[string]string +} + +func (s *settingHandlerRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.lastUpdates = make(map[string]string, len(settings)) + for key, value := range settings { + s.lastUpdates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *settingHandlerRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/admin/settings", nil) + + handler.GetSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 9.5, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) + + subscriptions, ok := data["auth_source_default_email_subscriptions"].([]any) + require.True(t, ok) + require.Len(t, subscriptions, 1) +} + +func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "false", + service.SettingKeyPromoCodeEnabled: "true", + service.SettingKeyAuthSourceDefaultEmailBalance: "9.5", + service.SettingKeyAuthSourceDefaultEmailConcurrency: "8", + service.SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":31,"validity_days":15}]`, + service.SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + service.SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "registration_enabled": true, + "promo_code_enabled": true, + "auth_source_default_email_balance": 12.75, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "12.75000000", repo.values[service.SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "8", repo.values[service.SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, `[{"group_id":31,"validity_days":15}]`, repo.values[service.SettingKeyAuthSourceDefaultEmailSubscriptions]) + require.Equal(t, "true", repo.values[service.SettingKeyForceEmailOnThirdPartySignup]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, 12.75, data["auth_source_default_email_balance"]) + require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) + require.Equal(t, true, data["force_email_on_third_party_signup"]) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 2f182642..b0edcf5a 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -219,7 +219,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -262,6 +262,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { ProviderKey: "linuxdo", ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -287,7 +288,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } type completeLinuxDoOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteLinuxDoOAuthRegistration completes a pending OAuth registration by validating @@ -335,11 +338,23 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 90bc10d1..661c0da0 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -1,10 +1,21 @@ package handler import ( + "bytes" + "context" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -110,3 +121,79 @@ func TestSingleLineStripsWhitespace(t *testing.T) { require.Equal(t, "hello world", singleLine("hello\r\nworld")) require.Equal(t, "", singleLine("\n\t\r")) } + +func TestCompleteLinuxDoOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("linuxdo-complete-session"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("linuxdo-subject-1"). + SetResolvedEmail("linuxdo-subject-1@linuxdo-connect.invalid"). + SetBrowserSessionKey("linuxdo-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "LinuxDo Display", + "suggested_avatar_url": "https://cdn.example/linuxdo.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/linuxdo/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("linuxdo-browser")}) + c.Request = req + + handler.CompleteLinuxDoOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "LinuxDo Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("linuxdo-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "LinuxDo Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/linuxdo.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index a758c0b9..da8ac858 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -1,10 +1,17 @@ package handler import ( + "context" + "errors" + "io" "net/http" "net/url" "strings" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -26,6 +33,7 @@ const ( type oauthPendingSessionPayload struct { Intent string Identity service.PendingAuthIdentityKey + TargetUserID *int64 ResolvedEmail string RedirectTo string BrowserSessionKey string @@ -33,6 +41,11 @@ type oauthPendingSessionPayload struct { CompletionResponse map[string]any } +type oauthAdoptionDecisionRequest struct { + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -125,6 +138,7 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen session, err := svc.CreatePendingSession(c.Request.Context(), service.CreatePendingAuthSessionInput{ Intent: strings.TrimSpace(payload.Intent), Identity: payload.Identity, + TargetUserID: payload.TargetUserID, ResolvedEmail: strings.TrimSpace(payload.ResolvedEmail), RedirectTo: strings.TrimSpace(payload.RedirectTo), BrowserSessionKey: strings.TrimSpace(payload.BrowserSessionKey), @@ -175,6 +189,291 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } +func (r oauthAdoptionDecisionRequest) hasDecision() bool { + return r.AdoptDisplayName != nil || r.AdoptAvatar != nil +} + +func (r oauthAdoptionDecisionRequest) toServiceInput(sessionID int64) service.PendingIdentityAdoptionDecisionInput { + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if r.AdoptDisplayName != nil { + input.AdoptDisplayName = *r.AdoptDisplayName + } + if r.AdoptAvatar != nil { + input.AdoptAvatar = *r.AdoptAvatar + } + return input +} + +func bindOptionalOAuthAdoptionDecision(c *gin.Context) (oauthAdoptionDecisionRequest, error) { + var req oauthAdoptionDecisionRequest + if c == nil || c.Request == nil || c.Request.Body == nil { + return req, nil + } + if err := c.ShouldBindJSON(&req); err != nil { + if errors.Is(err, io.EOF) { + return req, nil + } + return req, err + } + return req, nil +} + +func persistPendingOAuthAdoptionDecision( + c *gin.Context, + svc *service.AuthPendingIdentityService, + sessionID int64, + req oauthAdoptionDecisionRequest, +) error { + if !req.hasDecision() { + return nil + } + if svc == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + if _, err := svc.UpsertAdoptionDecision(c.Request.Context(), req.toServiceInput(sessionID)); err != nil { + return infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return nil +} + +func cloneOAuthMetadata(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func normalizeAdoptedOAuthDisplayName(value string) string { + value = strings.TrimSpace(value) + if len([]rune(value)) > 100 { + value = string([]rune(value)[:100]) + } + return value +} + +func (h *AuthHandler) entClient() *dbent.Client { + if h == nil || h.authService == nil { + return nil + } + return h.authService.EntClient() +} + +func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + existing, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(sessionID)). + Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_LOAD_FAILED", "failed to load oauth profile adoption decision").WithCause(err) + } + if existing != nil && !req.hasDecision() { + return existing, nil + } + if existing == nil && !req.hasDecision() { + return nil, nil + } + + input := service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + } + if existing != nil { + input.AdoptDisplayName = existing.AdoptDisplayName + input.AdoptAvatar = existing.AdoptAvatar + input.IdentityID = existing.IdentityID + } + if req.AdoptDisplayName != nil { + input.AdoptDisplayName = *req.AdoptDisplayName + } + if req.AdoptAvatar != nil { + input.AdoptAvatar = *req.AdoptAvatar + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err := svc.UpsertAdoptionDecision(c.Request.Context(), input) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { + if session == nil { + return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + if session.TargetUserID != nil && *session.TargetUserID > 0 { + return *session.TargetUserID, nil + } + email := strings.TrimSpace(session.ResolvedEmail) + if email == "" { + return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") + } + return 0, err + } + return userEntity.ID, nil +} + +func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { + if session == nil { + return nil + } + switch strings.TrimSpace(session.ProviderType) { + case "oidc": + issuer := strings.TrimSpace(session.ProviderKey) + if issuer == "" { + issuer = pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + } + if issuer == "" { + return nil + } + return &issuer + default: + issuer := pendingSessionStringValue(session.UpstreamIdentityClaims, "issuer") + if issuer == "" { + return nil + } + return &issuer + } +} + +func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(session.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(session.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(session.ProviderSubject)), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil { + if identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + return identity, nil + } + + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(strings.TrimSpace(session.ProviderType)). + SetProviderKey(strings.TrimSpace(session.ProviderKey)). + SetProviderSubject(strings.TrimSpace(session.ProviderSubject)). + SetMetadata(cloneOAuthMetadata(session.UpstreamIdentityClaims)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + return create.Save(ctx) +} + +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + if client == nil || session == nil || decision == nil { + return nil + } + if !decision.AdoptDisplayName && !decision.AdoptAvatar { + return nil + } + + targetUserID := int64(0) + if overrideUserID != nil && *overrideUserID > 0 { + targetUserID = *overrideUserID + } else { + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + if err != nil { + return err + } + targetUserID = resolvedUserID + } + + adoptedDisplayName := "" + if decision.AdoptDisplayName { + adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) + } + adoptedAvatarURL := "" + if decision.AdoptAvatar { + adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") + } + + tx, err := client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + if decision.AdoptDisplayName && adoptedDisplayName != "" { + if err := tx.Client().User.UpdateOneID(targetUserID). + SetUsername(adoptedDisplayName). + Exec(ctx); err != nil { + return err + } + } + + identity, err := ensurePendingOAuthIdentityForUser(ctx, tx, session, targetUserID) + if err != nil { + return err + } + + metadata := cloneOAuthMetadata(identity.Metadata) + for key, value := range session.UpstreamIdentityClaims { + metadata[key] = value + } + if decision.AdoptDisplayName && adoptedDisplayName != "" { + metadata["display_name"] = adoptedDisplayName + } + if decision.AdoptAvatar && adoptedAvatarURL != "" { + metadata["avatar_url"] = adoptedAvatarURL + } + + updateIdentity := tx.Client().AuthIdentity.UpdateOneID(identity.ID).SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + updateIdentity = updateIdentity.SetIssuer(strings.TrimSpace(*issuer)) + } + if _, err := updateIdentity.Save(ctx); err != nil { + return err + } + + if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). + SetIdentityID(identity.ID). + Save(ctx); err != nil { + return err + } + } + + return tx.Commit() +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -206,6 +505,11 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) } + adoptionDecision, err := bindOptionalOAuthAdoptionDecision(c) + if err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } sessionToken, err := readOAuthPendingSessionCookie(c) if err != nil || strings.TrimSpace(sessionToken) == "" { @@ -248,9 +552,30 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { applySuggestedProfileToCompletionResponse(payload, session.UpstreamIdentityClaims) if pendingSessionWantsInvitation(payload) { + if adoptionDecision.hasDecision() { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + _ = decision + } + response.Success(c, payload) + return + } + if !adoptionDecision.hasDecision() { response.Success(c, payload) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, adoptionDecision) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, session.TargetUserID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := svc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearCookies() diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 5517bae2..829fc217 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -1,9 +1,30 @@ package handler import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "net/http" + "net/http/httptest" "testing" + "time" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestApplySuggestedProfileToCompletionResponse(t *testing.T) { @@ -38,3 +59,439 @@ func TestApplySuggestedProfileToCompletionResponseKeepsExistingPayloadValues(t * require.Equal(t, "https://cdn.example/avatar.png", payload["suggested_avatar_url"]) require.Equal(t, true, payload["adoption_required"]) } + +func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + userEntity, err := client.User.Create(). + SetEmail("linuxdo-123@linuxdo-connect.invalid"). + SetUsername("legacy-name"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("pending-session-token"). + SetIntent("login"). + SetProviderType("linuxdo"). + SetProviderKey("linuxdo"). + SetProviderSubject("123"). + SetTargetUserID(userEntity.ID). + SetResolvedEmail(userEntity.Email). + SetBrowserSessionKey("browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "linuxdo_user", + "suggested_display_name": "Alice Example", + "suggested_avatar_url": "https://cdn.example/alice.png", + }). + SetLocalFlowState(map[string]any{ + oauthCompletionResponseKey: map[string]any{ + "access_token": "access-token", + "redirect": "/dashboard", + }, + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + previewRecorder := httptest.NewRecorder() + previewCtx, _ := gin.CreateTestContext(previewRecorder) + previewReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", nil) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + previewReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + previewCtx.Request = previewReq + + handler.ExchangePendingOAuthCompletion(previewCtx) + + require.Equal(t, http.StatusOK, previewRecorder.Code) + previewData := decodeJSONResponseData(t, previewRecorder) + require.Equal(t, "Alice Example", previewData["suggested_display_name"]) + require.Equal(t, "https://cdn.example/alice.png", previewData["suggested_avatar_url"]) + require.Equal(t, true, previewData["adoption_required"]) + + storedUser, err := client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "legacy-name", storedUser.Username) + + previewSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.Nil(t, previewSession.ConsumedAt) + + body := bytes.NewBufferString(`{"adopt_display_name":true,"adopt_avatar":true}`) + finalizeRecorder := httptest.NewRecorder() + finalizeCtx, _ := gin.CreateTestContext(finalizeRecorder) + finalizeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/pending/exchange", body) + finalizeReq.Header.Set("Content-Type", "application/json") + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + finalizeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-session-key")}) + finalizeCtx.Request = finalizeReq + + handler.ExchangePendingOAuthCompletion(finalizeCtx) + + require.Equal(t, http.StatusOK, finalizeRecorder.Code) + + storedUser, err = client.User.Get(ctx, userEntity.ID) + require.NoError(t, err) + require.Equal(t, "Alice Example", storedUser.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("linuxdo"), + authidentity.ProviderKeyEQ("linuxdo"), + authidentity.ProviderSubjectEQ("123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "Alice Example", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + } + settingSvc := service.NewSettingService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) + authSvc := service.NewAuthService( + client, + &oauthPendingFlowUserRepo{client: client}, + nil, + &oauthPendingFlowRefreshTokenCacheStub{}, + cfg, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func boolSettingValue(v bool) string { + if v { + return "true" + } + return "false" +} + +func boolPtr(v bool) *bool { + return &v +} + +type oauthPendingFlowSettingRepoStub struct { + values map[string]string +} + +func (s *oauthPendingFlowSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *oauthPendingFlowSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *oauthPendingFlowSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type oauthPendingFlowRefreshTokenCacheStub struct{} + +func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *oauthPendingFlowRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + +func decodeJSONResponseData(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var envelope struct { + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &envelope)) + return envelope.Data +} + +func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[string]any { + t.Helper() + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + return payload +} + +type oauthPendingFlowUserRepo struct { + client *dbent.Client +} + +func (r *oauthPendingFlowUserRepo) Create(ctx context.Context, user *service.User) error { + entity, err := r.client.User.Create(). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.ID = entity.ID + user.CreatedAt = entity.CreatedAt + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + entity, err := r.client.User.Get(ctx, id) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + entity, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrUserNotFound + } + return nil, err + } + return oauthPendingFlowServiceUser(entity), nil +} + +func (r *oauthPendingFlowUserRepo) GetFirstAdmin(context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (r *oauthPendingFlowUserRepo) Update(ctx context.Context, user *service.User) error { + entity, err := r.client.User.UpdateOneID(user.ID). + SetEmail(user.Email). + SetUsername(user.Username). + SetNotes(user.Notes). + SetPasswordHash(user.PasswordHash). + SetRole(user.Role). + SetBalance(user.Balance). + SetConcurrency(user.Concurrency). + SetStatus(user.Status). + SetSignupSource(user.SignupSource). + SetNillableLastLoginAt(user.LastLoginAt). + SetNillableLastActiveAt(user.LastActiveAt). + Save(ctx) + if err != nil { + return err + } + user.UpdatedAt = entity.UpdatedAt + return nil +} + +func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { + return r.client.User.DeleteOneID(id).Exec(ctx) +} + +func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + return nil, service.ErrUserNotFound +} + +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { + return nil +} + +func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (r *oauthPendingFlowUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (r *oauthPendingFlowUserRepo) UpdateBalance(context.Context, int64, float64) error { + panic("unexpected UpdateBalance call") +} + +func (r *oauthPendingFlowUserRepo) DeductBalance(context.Context, int64, float64) error { + panic("unexpected DeductBalance call") +} + +func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int) error { + panic("unexpected UpdateConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) + return count > 0, err +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { + panic("unexpected AddGroupToAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected RemoveGroupFromUserAllowedGroups call") +} + +func (r *oauthPendingFlowUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (r *oauthPendingFlowUserRepo) EnableTotp(context.Context, int64) error { + panic("unexpected EnableTotp call") +} + +func (r *oauthPendingFlowUserRepo) DisableTotp(context.Context, int64) error { + panic("unexpected DisableTotp call") +} + +func oauthPendingFlowServiceUser(entity *dbent.User) *service.User { + if entity == nil { + return nil + } + return &service.User{ + ID: entity.ID, + Email: entity.Email, + Username: entity.Username, + Notes: entity.Notes, + PasswordHash: entity.PasswordHash, + Role: entity.Role, + Balance: entity.Balance, + Concurrency: entity.Concurrency, + Status: entity.Status, + SignupSource: entity.SignupSource, + LastLoginAt: entity.LastLoginAt, + LastActiveAt: entity.LastActiveAt, + CreatedAt: entity.CreatedAt, + UpdatedAt: entity.UpdatedAt, + } +} diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index e3694c8f..ceda633c 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -326,7 +326,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ) // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ @@ -371,6 +371,7 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { ProviderKey: issuer, ProviderSubject: subject, }, + TargetUserID: &user.ID, ResolvedEmail: email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, @@ -399,7 +400,9 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } type completeOIDCOAuthRequest struct { - InvitationCode string `json:"invitation_code" binding:"required"` + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } // CompleteOIDCOAuthRegistration completes a pending OAuth registration by validating @@ -447,11 +450,23 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { return } - tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { clearOAuthPendingSessionCookie(c, secureCookie) clearOAuthPendingBrowserCookie(c, secureCookie) diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index c389db51..9107e13a 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -1,6 +1,7 @@ package handler import ( + "bytes" "context" "crypto/rand" "crypto/rsa" @@ -12,7 +13,13 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -123,3 +130,80 @@ func buildRSAJWK(kid string, pub *rsa.PublicKey) oidcJWK { E: e, } } + +func TestCompleteOIDCOAuthRegistrationAppliesPendingAdoptionDecision(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("oidc-complete-session"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example.com"). + SetProviderSubject("oidc-subject-1"). + SetResolvedEmail("93a310f4c1944c5bbd2e246df1f76485@oidc-connect.invalid"). + SetBrowserSessionKey("oidc-browser"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "issuer": "https://issuer.example.com", + "suggested_display_name": "OIDC Display", + "suggested_avatar_url": "https://cdn.example/oidc.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = service.NewAuthPendingIdentityService(client).UpsertAdoptionDecision(ctx, service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptAvatar: true, + }) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/complete-registration", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("oidc-browser")}) + c.Request = req + + handler.CompleteOIDCOAuthRegistration(c) + + require.Equal(t, http.StatusOK, recorder.Code) + responseData := decodeJSONBody(t, recorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ(session.ResolvedEmail)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "OIDC Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("oidc-subject-1"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "OIDC Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/oidc.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(session.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go new file mode 100644 index 00000000..867a77a1 --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -0,0 +1,618 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +const ( + wechatOAuthCookiePath = "/api/v1/auth/oauth/wechat" + wechatOAuthCookieMaxAgeSec = 10 * 60 + wechatOAuthStateCookieName = "wechat_oauth_state" + wechatOAuthRedirectCookieName = "wechat_oauth_redirect" + wechatOAuthIntentCookieName = "wechat_oauth_intent" + wechatOAuthModeCookieName = "wechat_oauth_mode" + wechatOAuthDefaultRedirectTo = "/dashboard" + wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" + wechatOAuthProviderKey = "wechat-main" + + wechatOAuthIntentLogin = "login" + wechatOAuthIntentBind = "bind_current_user" + wechatOAuthIntentAdoptEmail = "adopt_existing_user_by_email" +) + +var ( + wechatOAuthAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token" + wechatOAuthUserInfoURL = "https://api.weixin.qq.com/sns/userinfo" +) + +type wechatOAuthConfig struct { + mode string + appID string + appSecret string + authorizeURL string + scope string + redirectURI string + frontendCallback string +} + +type wechatOAuthTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +type wechatOAuthUserInfoResponse struct { + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + HeadImgURL string `json:"headimgurl"` + UnionID string `json:"unionid"` + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} + +// WeChatOAuthStart starts the WeChat OAuth login flow and stores the short-lived +// browser cookies required by the rebuild pending-auth bridge. +func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), c.Query("mode"), c) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + intent := normalizeWeChatOAuthIntent(c.Query("intent")) + secureCookie := isRequestHTTPS(c) + wechatSetCookie(c, wechatOAuthStateCookieName, encodeCookieValue(state), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthRedirectCookieName, encodeCookieValue(redirectTo), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthIntentCookieName, encodeCookieValue(intent), wechatOAuthCookieMaxAgeSec, secureCookie) + wechatSetCookie(c, wechatOAuthModeCookieName, encodeCookieValue(cfg.mode), wechatOAuthCookieMaxAgeSec, secureCookie) + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + + authURL, err := buildWeChatAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, +// and stores the result in the unified pending-auth flow. +func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { + frontendCallback := wechatOAuthFrontendCallback() + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + wechatClearCookie(c, wechatOAuthStateCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthRedirectCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthIntentCookieName, secureCookie) + wechatClearCookie(c, wechatOAuthModeCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, wechatOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, wechatOAuthRedirectCookieName) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = wechatOAuthDefaultRedirectTo + } + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing oauth browser session", "") + return + } + + intent, _ := readCookieDecoded(c, wechatOAuthIntentCookieName) + mode, err := readCookieDecoded(c, wechatOAuthModeCookieName) + if err != nil || strings.TrimSpace(mode) == "" { + redirectOAuthError(c, frontendCallback, "invalid_state", "missing oauth mode", "") + return + } + + cfg, err := h.getWeChatOAuthConfig(c.Request.Context(), mode, c) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + tokenResp, userInfo, err := fetchWeChatOAuthIdentity(c.Request.Context(), cfg, code) + if err != nil { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_identity_fetch_failed", singleLine(err.Error())) + return + } + + unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) + openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) + providerSubject := firstNonEmpty(unionid, openid) + if providerSubject == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + return + } + + username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) + email := wechatSyntheticEmail(providerSubject) + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": providerSubject, + "openid": openid, + "unionid": unionid, + "mode": cfg.mode, + "suggested_display_name": strings.TrimSpace(userInfo.Nickname), + "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), + } + + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") + if err != nil { + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if err := h.createWeChatPendingSession(c, normalizeWeChatOAuthIntent(intent), providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +type completeWeChatOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteWeChatOAuthRegistration completes a pending WeChat OAuth registration by +// validating the invitation code and consuming the current pending browser session. +// POST /api/v1/auth/oauth/wechat/complete-registration +func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { + var req completeWeChatOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode) + if err != nil { + response.ErrorFrom(c, err) + return + } + decision, err := h.upsertPendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), session, decision, &user.ID); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), sessionToken, browserSessionKey); err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) createWeChatPendingSession( + c *gin.Context, + intent string, + providerSubject string, + email string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + tokenPair *service.TokenPair, + authErr error, +) error { + completionResponse := map[string]any{ + "redirect": redirectTo, + } + if authErr != nil { + if errors.Is(authErr, service.ErrOAuthInvitationRequired) { + completionResponse["error"] = "invitation_required" + } else { + return authErr + } + } else if tokenPair != nil { + completionResponse["access_token"] = tokenPair.AccessToken + completionResponse["refresh_token"] = tokenPair.RefreshToken + completionResponse["expires_in"] = tokenPair.ExpiresIn + completionResponse["token_type"] = "Bearer" + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: intent, + Identity: service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + }, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { + mode, err := resolveWeChatOAuthMode(rawMode, c) + if err != nil { + return wechatOAuthConfig{}, err + } + + apiBaseURL := "" + if h != nil && h.settingSvc != nil { + settings, err := h.settingSvc.GetAllSettings(ctx) + if err == nil && settings != nil { + apiBaseURL = strings.TrimSpace(settings.APIBaseURL) + } + } + + cfg := wechatOAuthConfig{ + mode: mode, + redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"), + frontendCallback: wechatOAuthFrontendCallback(), + } + + switch mode { + case "mp": + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" + cfg.scope = "snsapi_userinfo" + default: + cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) + cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) + cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" + cfg.scope = "snsapi_login" + } + + if cfg.appID == "" || cfg.appSecret == "" { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if strings.TrimSpace(cfg.redirectURI) == "" { + return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + + return cfg, nil +} + +func wechatOAuthFrontendCallback() string { + return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB) +} + +func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { + mode := strings.ToLower(strings.TrimSpace(rawMode)) + if mode == "" { + if isWeChatBrowserRequest(c) { + return "mp", nil + } + return "open", nil + } + if mode != "open" && mode != "mp" { + return "", infraerrors.BadRequest("INVALID_MODE", "wechat oauth mode must be open or mp") + } + return mode, nil +} + +func isWeChatBrowserRequest(c *gin.Context) bool { + if c == nil || c.Request == nil { + return false + } + return strings.Contains(strings.ToLower(strings.TrimSpace(c.GetHeader("User-Agent"))), "micromessenger") +} + +func normalizeWeChatOAuthIntent(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "", "login": + return wechatOAuthIntentLogin + case "bind", "bind_current_user": + return wechatOAuthIntentBind + case "adopt", "adopt_existing_user_by_email": + return wechatOAuthIntentAdoptEmail + default: + return wechatOAuthIntentLogin + } +} + +func buildWeChatAuthorizeURL(cfg wechatOAuthConfig, state string) (string, error) { + u, err := url.Parse(cfg.authorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize url: %w", err) + } + query := u.Query() + query.Set("appid", cfg.appID) + query.Set("redirect_uri", cfg.redirectURI) + query.Set("response_type", "code") + query.Set("scope", cfg.scope) + query.Set("state", state) + u.RawQuery = query.Encode() + u.Fragment = "wechat_redirect" + return u.String(), nil +} + +func resolveWeChatOAuthAbsoluteURL(apiBaseURL string, c *gin.Context, callbackPath string) string { + callbackPath = strings.TrimSpace(callbackPath) + if callbackPath == "" { + return "" + } + + if raw := strings.TrimSpace(apiBaseURL); raw != "" { + if parsed, err := url.Parse(raw); err == nil && parsed.Scheme != "" && parsed.Host != "" { + basePath := strings.TrimRight(parsed.EscapedPath(), "/") + targetPath := callbackPath + if basePath != "" && strings.HasSuffix(basePath, "/api/v1") && strings.HasPrefix(callbackPath, "/api/v1") { + targetPath = basePath + strings.TrimPrefix(callbackPath, "/api/v1") + } else if basePath != "" { + targetPath = basePath + callbackPath + } + return parsed.Scheme + "://" + parsed.Host + targetPath + } + } + + if c == nil || c.Request == nil { + return "" + } + scheme := "http" + if isRequestHTTPS(c) { + scheme = "https" + } + host := strings.TrimSpace(c.Request.Host) + if forwardedHost := strings.TrimSpace(c.GetHeader("X-Forwarded-Host")); forwardedHost != "" { + host = forwardedHost + } + if host == "" { + return "" + } + return scheme + "://" + host + callbackPath +} + +func fetchWeChatOAuthIdentity(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, *wechatOAuthUserInfoResponse, error) { + tokenResp, err := exchangeWeChatOAuthCode(ctx, cfg, code) + if err != nil { + return nil, nil, err + } + userInfo, err := fetchWeChatUserInfo(ctx, tokenResp) + if err != nil { + return nil, nil, err + } + return tokenResp, userInfo, nil +} + +func exchangeWeChatOAuthCode(ctx context.Context, cfg wechatOAuthConfig, code string) (*wechatOAuthTokenResponse, error) { + endpoint, err := url.Parse(wechatOAuthAccessTokenURL) + if err != nil { + return nil, fmt.Errorf("parse wechat access token url: %w", err) + } + + query := endpoint.Query() + query.Set("appid", cfg.appID) + query.Set("secret", cfg.appSecret) + query.Set("code", strings.TrimSpace(code)) + query.Set("grant_type", "authorization_code") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat access token request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat access token response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat access token status=%d", resp.StatusCode) + } + + var tokenResp wechatOAuthTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("decode wechat access token response: %w", err) + } + if tokenResp.ErrCode != 0 { + return nil, fmt.Errorf("wechat access token error=%d %s", tokenResp.ErrCode, strings.TrimSpace(tokenResp.ErrMsg)) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, fmt.Errorf("wechat access token missing access_token") + } + return &tokenResp, nil +} + +func fetchWeChatUserInfo(ctx context.Context, tokenResp *wechatOAuthTokenResponse) (*wechatOAuthUserInfoResponse, error) { + if tokenResp == nil { + return nil, fmt.Errorf("wechat token response is nil") + } + + endpoint, err := url.Parse(wechatOAuthUserInfoURL) + if err != nil { + return nil, fmt.Errorf("parse wechat userinfo url: %w", err) + } + query := endpoint.Query() + query.Set("access_token", strings.TrimSpace(tokenResp.AccessToken)) + query.Set("openid", strings.TrimSpace(tokenResp.OpenID)) + query.Set("lang", "zh_CN") + endpoint.RawQuery = query.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) + if err != nil { + return nil, fmt.Errorf("build wechat userinfo request: %w", err) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request wechat userinfo: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read wechat userinfo response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("wechat userinfo status=%d", resp.StatusCode) + } + + var userInfo wechatOAuthUserInfoResponse + if err := json.Unmarshal(body, &userInfo); err != nil { + return nil, fmt.Errorf("decode wechat userinfo response: %w", err) + } + if userInfo.ErrCode != 0 { + return nil, fmt.Errorf("wechat userinfo error=%d %s", userInfo.ErrCode, strings.TrimSpace(userInfo.ErrMsg)) + } + return &userInfo, nil +} + +func wechatSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "wechat-" + subject + service.WeChatConnectSyntheticEmailDomain +} + +func wechatFallbackUsername(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "wechat_user" + } + return "wechat_" + truncateFragmentValue(subject) +} + +func wechatSetCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: wechatOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func wechatClearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: wechatOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go new file mode 100644 index 00000000..1a765dcc --- /dev/null +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -0,0 +1,411 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "database/sql" + "encoding/base64" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) + c.Request.Host = "api.example.com" + + handler := &AuthHandler{} + handler.WeChatOAuthStart(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.NotEmpty(t, location) + require.Contains(t, location, "open.weixin.qq.com") + require.Contains(t, location, "appid=wx-open-app") + require.Contains(t, location, "scope=snsapi_login") + + cookies := recorder.Result().Cookies() + require.NotEmpty(t, findCookie(cookies, wechatOAuthStateCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthRedirectCookieName)) + require.NotEmpty(t, findCookie(cookies, wechatOAuthModeCookieName)) + require.NotEmpty(t, findCookie(cookies, oauthPendingBrowserCookieName)) +} + +func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + ctx := context.Background() + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "wechat", session.ProviderType) + require.Equal(t, "wechat-main", session.ProviderKey) + require.Equal(t, "union-456", session.ProviderSubject) + require.Equal(t, "wechat-union-456@wechat-connect.invalid", session.ResolvedEmail) + require.Equal(t, "WeChat Nick", session.UpstreamIdentityClaims["suggested_display_name"]) + require.Equal(t, "https://cdn.example/avatar.png", session.UpstreamIdentityClaims["suggested_avatar_url"]) + require.Equal(t, "union-456", session.UpstreamIdentityClaims["unionid"]) + require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) +} + +func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"WeChat Display","headimgurl":"https://cdn.example/wechat.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, true) + defer client.Close() + + ctx := context.Background() + redeemRepo := repository.NewRedeemCodeRepository(client) + require.NoError(t, redeemRepo.Create(ctx, &service.RedeemCode{ + Code: "invite-1", + Type: service.RedeemTypeInvitation, + Status: service.StatusUnused, + })) + + callbackRecorder := httptest.NewRecorder() + callbackCtx, _ := gin.CreateTestContext(callbackRecorder) + callbackReq := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + callbackReq.Host = "api.example.com" + callbackReq.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + callbackReq.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + callbackReq.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + callbackReq.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + callbackCtx.Request = callbackReq + + handler.WeChatOAuthCallback(callbackCtx) + + require.Equal(t, http.StatusFound, callbackRecorder.Code) + require.Equal(t, "/auth/wechat/callback", callbackRecorder.Header().Get("Location")) + + sessionCookie := findCookie(callbackRecorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + sessionToken := decodeCookieValueForTest(t, sessionCookie.Value) + + pendingSession, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "invitation_required", pendingSession.LocalFlowState[oauthCompletionResponseKey].(map[string]any)["error"]) + + body := bytes.NewBufferString(`{"invitation_code":"invite-1","adopt_display_name":true,"adopt_avatar":true}`) + completeRecorder := httptest.NewRecorder() + completeCtx, _ := gin.CreateTestContext(completeRecorder) + completeReq := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/wechat/complete-registration", body) + completeReq.Header.Set("Content-Type", "application/json") + completeReq.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(sessionToken)}) + completeReq.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("browser-123")}) + completeCtx.Request = completeReq + + handler.CompleteWeChatOAuthRegistration(completeCtx) + + require.Equal(t, http.StatusOK, completeRecorder.Code) + responseData := decodeJSONBody(t, completeRecorder) + require.NotEmpty(t, responseData["access_token"]) + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEQ("wechat-union-456@wechat-connect.invalid")). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "WeChat Display", userEntity.Username) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, userEntity.ID, identity.UserID) + require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) + require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + + decision, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) + require.True(t, decision.AdoptDisplayName) + require.True(t, decision.AdoptAvatar) + + consumed, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.IDEQ(pendingSession.ID)). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) +} + +func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + + userRepo := &oauthPendingFlowUserRepo{client: client} + redeemRepo := repository.NewRedeemCodeRepository(client) + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }) + + authSvc := service.NewAuthService( + client, + userRepo, + redeemRepo, + &wechatOAuthRefreshTokenCacheStub{}, + &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-secret", + ExpireHour: 1, + AccessTokenExpireMinutes: 60, + RefreshTokenExpireDays: 7, + }, + Default: config.DefaultConfig{ + UserBalance: 0, + UserConcurrency: 1, + }, + }, + settingSvc, + nil, + nil, + nil, + nil, + nil, + ) + + return &AuthHandler{ + authService: authSvc, + settingSvc: settingSvc, + }, client +} + +func encodedCookie(name, value string) *http.Cookie { + return &http.Cookie{ + Name: name, + Value: encodeCookieValue(value), + Path: "/", + } +} + +func findCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func decodeCookieValueForTest(t *testing.T, value string) string { + t.Helper() + raw, err := base64.RawURLEncoding.DecodeString(value) + require.NoError(t, err) + return string(raw) +} + +type wechatOAuthSettingRepoStub struct { + values map[string]string +} + +func (s *wechatOAuthSettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + return nil, service.ErrSettingNotFound +} + +func (s *wechatOAuthSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *wechatOAuthSettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + result[key] = value + } + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (s *wechatOAuthSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + result := make(map[string]string, len(s.values)) + for key, value := range s.values { + result[key] = value + } + return result, nil +} + +func (s *wechatOAuthSettingRepoStub) Delete(context.Context, string) error { + return nil +} + +type wechatOAuthRefreshTokenCacheStub struct{} + +func (s *wechatOAuthRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetRefreshToken(context.Context, string) (*service.RefreshTokenData, error) { + return nil, service.ErrRefreshTokenNotFound +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *wechatOAuthRefreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 3659e79b..f44b3e3b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -189,6 +189,7 @@ type PublicSettings struct { CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` OIDCOAuthProviderName string `json:"oidc_oauth_provider_name"` SoraClientEnabled bool `json:"sora_client_enabled"` diff --git a/backend/internal/handler/payment_webhook_handler.go b/backend/internal/handler/payment_webhook_handler.go index 8a83bfeb..9fdefa93 100644 --- a/backend/internal/handler/payment_webhook_handler.go +++ b/backend/internal/handler/payment_webhook_handler.go @@ -120,7 +120,7 @@ func (h *PaymentWebhookHandler) handleNotify(c *gin.Context, providerKey string) // This allows looking up the correct provider instance before verification. func extractOutTradeNo(rawBody, providerKey string) string { switch providerKey { - case payment.TypeEasyPay: + case payment.TypeEasyPay, payment.TypeAlipay: values, err := url.ParseQuery(rawBody) if err == nil { return values.Get("out_trade_no") diff --git a/backend/internal/handler/payment_webhook_handler_test.go b/backend/internal/handler/payment_webhook_handler_test.go index bdef1766..6f448131 100644 --- a/backend/internal/handler/payment_webhook_handler_test.go +++ b/backend/internal/handler/payment_webhook_handler_test.go @@ -97,3 +97,37 @@ func TestWebhookConstants(t *testing.T) { assert.Equal(t, 200, webhookLogTruncateLen) }) } + +func TestExtractOutTradeNo(t *testing.T) { + tests := []struct { + name string + providerKey string + rawBody string + want string + }{ + { + name: "easypay query payload", + providerKey: "easypay", + rawBody: "out_trade_no=sub2_123&trade_status=TRADE_SUCCESS", + want: "sub2_123", + }, + { + name: "alipay query payload", + providerKey: "alipay", + rawBody: "notify_time=2026-04-20+12%3A00%3A00&out_trade_no=sub2_456", + want: "sub2_456", + }, + { + name: "unknown provider", + providerKey: "wxpay", + rawBody: "{}", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, extractOutTradeNo(tt.rawBody, tt.providerKey)) + }) + } +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 1717b7a1..c7bc3e2a 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -56,6 +56,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, OIDCOAuthProviderName: settings.OIDCOAuthProviderName, BackendModeEnabled: settings.BackendModeEnabled, diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 2535ea5e..904341d0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -34,10 +34,16 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type userProfileResponse struct { + dto.User + AvatarURL string `json:"avatar_url,omitempty"` +} + // GetProfile handles getting user profile // GET /api/v1/users/me func (h *UserHandler) GetProfile(c *gin.Context) { @@ -47,13 +53,13 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - userData, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + userData, err := h.userService.GetProfile(c.Request.Context(), subject.UserID) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, dto.UserFromService(userData)) + response.Success(c, userProfileResponseFromService(userData)) } // ChangePassword handles changing user password @@ -101,6 +107,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { svcReq := service.UpdateProfileRequest{ Username: req.Username, + AvatarURL: req.AvatarURL, BalanceNotifyEnabled: req.BalanceNotifyEnabled, BalanceNotifyThreshold: req.BalanceNotifyThreshold, } @@ -110,7 +117,7 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // SendNotifyEmailCodeRequest represents the request to send notify email verification code @@ -176,7 +183,7 @@ func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // RemoveNotifyEmailRequest represents the request to remove a notify email @@ -212,7 +219,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) } // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state @@ -248,5 +255,16 @@ func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { return } - response.Success(c, dto.UserFromService(updatedUser)) + response.Success(c, userProfileResponseFromService(updatedUser)) +} + +func userProfileResponseFromService(user *service.User) userProfileResponse { + base := dto.UserFromService(user) + if base == nil { + return userProfileResponse{} + } + return userProfileResponse{ + User: *base, + AvatarURL: user.AvatarURL, + } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go new file mode 100644 index 00000000..1973f59e --- /dev/null +++ b/backend/internal/handler/user_handler_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package handler + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userHandlerRepoStub struct { + user *service.User +} + +func (s *userHandlerRepoStub) Create(context.Context, *service.User) error { return nil } +func (s *userHandlerRepoStub) GetByID(context.Context, int64) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetByEmail(context.Context, string) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) GetFirstAdmin(context.Context) (*service.User, error) { + cloned := *s.user + return &cloned, nil +} +func (s *userHandlerRepoStub) Update(_ context.Context, user *service.User) error { + cloned := *user + s.user = &cloned + return nil +} +func (s *userHandlerRepoStub) Delete(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { + if s.user == nil || s.user.AvatarURL == "" { + return nil, nil + } + return &service.UserAvatar{ + StorageProvider: s.user.AvatarSource, + URL: s.user.AvatarURL, + ContentType: s.user.AvatarMIME, + ByteSize: s.user.AvatarByteSize, + SHA256: s.user.AvatarSHA256, + }, nil +} +func (s *userHandlerRepoStub) UpsertUserAvatar(_ context.Context, _ int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + s.user.AvatarURL = input.URL + s.user.AvatarSource = input.StorageProvider + s.user.AvatarMIME = input.ContentType + s.user.AvatarByteSize = input.ByteSize + s.user.AvatarSHA256 = input.SHA256 + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (s *userHandlerRepoStub) DeleteUserAvatar(context.Context, int64) error { + s.user.AvatarURL = "" + s.user.AvatarSource = "" + s.user.AvatarMIME = "" + s.user.AvatarByteSize = 0 + s.user.AvatarSHA256 = "" + return nil +} +func (s *userHandlerRepoStub) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } +func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + return nil +} +func (s *userHandlerRepoStub) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (s *userHandlerRepoStub) EnableTotp(context.Context, int64) error { return nil } +func (s *userHandlerRepoStub) DisableTotp(context.Context, int64) error { return nil } + +func TestUserHandlerUpdateProfileReturnsAvatarURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 11, + Email: "handler-avatar@example.com", + Username: "handler-avatar", + Role: service.RoleUser, + Status: service.StatusActive, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + body := []byte(`{"avatar_url":"https://cdn.example.com/avatar.png"}`) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/user", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 11}) + + handler.UpdateProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + AvatarURL string `json:"avatar_url"` + Username string `json:"username"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, "https://cdn.example.com/avatar.png", resp.Data.AvatarURL) + require.Equal(t, "handler-avatar", resp.Data.Username) +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 38ea9bde..36d80309 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -149,6 +149,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldBalanceNotifyThreshold, user.FieldBalanceNotifyExtraEmails, user.FieldTotalRecharged, + user.FieldSignupSource, + user.FieldLastLoginAt, + user.FieldLastActiveAt, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -656,6 +659,9 @@ func userEntityToService(u *dbent.User) *service.User { Balance: u.Balance, Concurrency: u.Concurrency, Status: u.Status, + SignupSource: u.SignupSource, + LastLoginAt: u.LastLoginAt, + LastActiveAt: u.LastActiveAt, TotpSecretEncrypted: u.TotpSecretEncrypted, TotpEnabled: u.TotpEnabled, TotpEnabledAt: u.TotpEnabledAt, diff --git a/backend/internal/repository/auth_identity_migration_report.go b/backend/internal/repository/auth_identity_migration_report.go new file mode 100644 index 00000000..70f298c1 --- /dev/null +++ b/backend/internal/repository/auth_identity_migration_report.go @@ -0,0 +1,148 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" +) + +type AuthIdentityMigrationReport struct { + ID int64 + ReportType string + ReportKey string + Details map[string]any + CreatedAt time.Time +} + +type AuthIdentityMigrationReportQuery struct { + ReportType string + Limit int + Offset int +} + +type AuthIdentityMigrationReportSummary struct { + Total int64 + ByType map[string]int64 +} + +func (r *userRepository) ListAuthIdentityMigrationReports(ctx context.Context, query AuthIdentityMigrationReportQuery) ([]AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + limit := query.Limit + if limit <= 0 { + limit = 100 + } + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1) +ORDER BY created_at DESC, id DESC +LIMIT $2 OFFSET $3`, + strings.TrimSpace(query.ReportType), + limit, + query.Offset, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + reports := make([]AuthIdentityMigrationReport, 0) + for rows.Next() { + report, scanErr := scanAuthIdentityMigrationReport(rows) + if scanErr != nil { + return nil, scanErr + } + reports = append(reports, report) + } + if err := rows.Err(); err != nil { + return nil, err + } + return reports, nil +} + +func (r *userRepository) GetAuthIdentityMigrationReport(ctx context.Context, reportType, reportKey string) (*AuthIdentityMigrationReport, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE report_type = $1 AND report_key = $2 +LIMIT 1`, + strings.TrimSpace(reportType), + strings.TrimSpace(reportKey), + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, sql.ErrNoRows + } + report, err := scanAuthIdentityMigrationReport(rows) + if err != nil { + return nil, err + } + return &report, rows.Err() +} + +func (r *userRepository) SummarizeAuthIdentityMigrationReports(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + + rows, err := exec.QueryContext(ctx, ` +SELECT report_type, COUNT(*) +FROM auth_identity_migration_reports +GROUP BY report_type +ORDER BY report_type ASC`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + summary := &AuthIdentityMigrationReportSummary{ + ByType: make(map[string]int64), + } + for rows.Next() { + var reportType string + var count int64 + if err := rows.Scan(&reportType, &count); err != nil { + return nil, err + } + summary.ByType[reportType] = count + summary.Total += count + } + if err := rows.Err(); err != nil { + return nil, err + } + return summary, nil +} + +func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { + var ( + report AuthIdentityMigrationReport + details []byte + ) + if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + return AuthIdentityMigrationReport{}, err + } + report.Details = map[string]any{} + if len(details) > 0 { + if err := json.Unmarshal(details, &report.Details); err != nil { + return AuthIdentityMigrationReport{}, err + } + } + return report, nil +} diff --git a/backend/internal/repository/user_profile_identity_repo.go b/backend/internal/repository/user_profile_identity_repo.go new file mode 100644 index 00000000..4ecae4a4 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo.go @@ -0,0 +1,544 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "time" + "unsafe" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +var ( + ErrAuthIdentityOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_OWNERSHIP_CONFLICT", + "auth identity already belongs to another user", + ) + ErrAuthIdentityChannelOwnershipConflict = infraerrors.Conflict( + "AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", + "auth identity channel already belongs to another user", + ) +) + +type ProviderGrantReason string + +const ( + ProviderGrantReasonSignup ProviderGrantReason = "signup" + ProviderGrantReasonFirstBind ProviderGrantReason = "first_bind" +) + +type AuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type AuthIdentityChannelKey struct { + ProviderType string + ProviderKey string + Channel string + ChannelAppID string + ChannelSubject string +} + +type CreateAuthIdentityInput struct { + UserID int64 + Canonical AuthIdentityKey + Channel *AuthIdentityChannelKey + Issuer *string + VerifiedAt *time.Time + Metadata map[string]any + ChannelMetadata map[string]any +} + +type BindAuthIdentityInput = CreateAuthIdentityInput + +type CreateAuthIdentityResult struct { + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +func (r *CreateAuthIdentityResult) IdentityRef() AuthIdentityKey { + if r == nil || r.Identity == nil { + return AuthIdentityKey{} + } + return AuthIdentityKey{ + ProviderType: r.Identity.ProviderType, + ProviderKey: r.Identity.ProviderKey, + ProviderSubject: r.Identity.ProviderSubject, + } +} + +func (r *CreateAuthIdentityResult) ChannelRef() *AuthIdentityChannelKey { + if r == nil || r.Channel == nil { + return nil + } + return &AuthIdentityChannelKey{ + ProviderType: r.Channel.ProviderType, + ProviderKey: r.Channel.ProviderKey, + Channel: r.Channel.Channel, + ChannelAppID: r.Channel.ChannelAppID, + ChannelSubject: r.Channel.ChannelSubject, + } +} + +type UserAuthIdentityLookup struct { + User *dbent.User + Identity *dbent.AuthIdentity + Channel *dbent.AuthIdentityChannel +} + +type ProviderGrantRecordInput struct { + UserID int64 + ProviderType string + GrantReason ProviderGrantReason +} + +type IdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type sqlQueryExecutor interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func (r *userRepository) WithUserProfileIdentityTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if dbent.TxFromContext(ctx) != nil { + return fn(ctx) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + return err + } + return tx.Commit() +} + +func (r *userRepository) CreateAuthIdentity(ctx context.Context, input CreateAuthIdentityInput) (*CreateAuthIdentityResult, error) { + client := clientFromContext(ctx, r.client) + + create := client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(input.Canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt) + + identity, err := create.Save(ctx) + if err != nil { + return nil, err + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(ctx) + if err != nil { + return nil, err + } + } + + return &CreateAuthIdentityResult{Identity: identity, Channel: channel}, nil +} + +func (r *userRepository) GetUserByCanonicalIdentity(ctx context.Context, key AuthIdentityKey) (*UserAuthIdentityLookup, error) { + identity, err := clientFromContext(ctx, r.client).AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(key.ProviderSubject)), + ). + WithUser(). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: identity.Edges.User, + Identity: identity, + }, nil +} + +func (r *userRepository) GetUserByChannelIdentity(ctx context.Context, key AuthIdentityChannelKey) (*UserAuthIdentityLookup, error) { + channel, err := clientFromContext(ctx, r.client).AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(key.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(key.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(key.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(key.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(key.ChannelSubject)), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil { + return nil, err + } + + return &UserAuthIdentityLookup{ + User: channel.Edges.Identity.Edges.User, + Identity: channel.Edges.Identity, + Channel: channel, + }, nil +} + +func (r *userRepository) BindAuthIdentityToUser(ctx context.Context, input BindAuthIdentityInput) (*CreateAuthIdentityResult, error) { + var result *CreateAuthIdentityResult + err := r.WithUserProfileIdentityTx(ctx, func(txCtx context.Context) error { + client := clientFromContext(txCtx, r.client) + canonical := input.Canonical + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(canonical.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(canonical.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(canonical.ProviderSubject)), + ). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if identity != nil && identity.UserID != input.UserID { + return ErrAuthIdentityOwnershipConflict + } + if identity == nil { + identity, err = client.AuthIdentity.Create(). + SetUserID(input.UserID). + SetProviderType(strings.TrimSpace(canonical.ProviderType)). + SetProviderKey(strings.TrimSpace(canonical.ProviderKey)). + SetProviderSubject(strings.TrimSpace(canonical.ProviderSubject)). + SetMetadata(copyMetadata(input.Metadata)). + SetNillableIssuer(input.Issuer). + SetNillableVerifiedAt(input.VerifiedAt). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentity.UpdateOneID(identity.ID) + if input.Metadata != nil { + update = update.SetMetadata(copyMetadata(input.Metadata)) + } + if input.Issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*input.Issuer)) + } + if input.VerifiedAt != nil { + update = update.SetVerifiedAt(*input.VerifiedAt) + } + identity, err = update.Save(txCtx) + if err != nil { + return err + } + } + + var channel *dbent.AuthIdentityChannel + if input.Channel != nil { + channel, err = client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(input.Channel.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(input.Channel.ProviderKey)), + authidentitychannel.ChannelEQ(strings.TrimSpace(input.Channel.Channel)), + authidentitychannel.ChannelAppIDEQ(strings.TrimSpace(input.Channel.ChannelAppID)), + authidentitychannel.ChannelSubjectEQ(strings.TrimSpace(input.Channel.ChannelSubject)), + ). + WithIdentity(). + Only(txCtx) + if err != nil && !dbent.IsNotFound(err) { + return err + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != input.UserID { + return ErrAuthIdentityChannelOwnershipConflict + } + if channel == nil { + channel, err = client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(strings.TrimSpace(input.Channel.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Channel.ProviderKey)). + SetChannel(strings.TrimSpace(input.Channel.Channel)). + SetChannelAppID(strings.TrimSpace(input.Channel.ChannelAppID)). + SetChannelSubject(strings.TrimSpace(input.Channel.ChannelSubject)). + SetMetadata(copyMetadata(input.ChannelMetadata)). + Save(txCtx) + if err != nil { + return err + } + } else { + update := client.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID) + if input.ChannelMetadata != nil { + update = update.SetMetadata(copyMetadata(input.ChannelMetadata)) + } + channel, err = update.Save(txCtx) + if err != nil { + return err + } + } + } + + result = &CreateAuthIdentityResult{Identity: identity, Channel: channel} + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +func (r *userRepository) RecordProviderGrant(ctx context.Context, input ProviderGrantRecordInput) (bool, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return false, fmt.Errorf("sql executor is not configured") + } + + result, err := exec.ExecContext(ctx, ` +INSERT INTO user_provider_default_grants (user_id, provider_type, grant_reason) +VALUES ($1, $2, $3) +ON CONFLICT (user_id, provider_type, grant_reason) DO NOTHING`, + input.UserID, + strings.TrimSpace(input.ProviderType), + string(input.GrantReason), + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *userRepository) UpsertIdentityAdoptionDecision(ctx context.Context, input IdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + client := clientFromContext(ctx, r.client) + current, err := client.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + now := time.Now().UTC() + if current == nil { + create := client.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(now) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := client.IdentityAdoptionDecision.UpdateOneID(current.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func (r *userRepository) GetIdentityAdoptionDecisionByPendingAuthSessionID(ctx context.Context, pendingAuthSessionID int64) (*dbent.IdentityAdoptionDecision, error) { + return clientFromContext(ctx, r.client).IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingAuthSessionID)). + Only(ctx) +} + +func (r *userRepository) UpdateUserLastLoginAt(ctx context.Context, userID int64, loginAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastLoginAt(loginAt). + Save(ctx) + return err +} + +func (r *userRepository) UpdateUserLastActiveAt(ctx context.Context, userID int64, activeAt time.Time) error { + _, err := clientFromContext(ctx, r.client).User.UpdateOneID(userID). + SetLastActiveAt(activeAt). + Save(ctx) + return err +} + +func (r *userRepository) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + rows, err := exec.QueryContext(ctx, ` +SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 +FROM user_avatars +WHERE user_id = $1`, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil +} + +func (r *userRepository) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return nil, err + } + + _, err = exec.ExecContext(ctx, ` +INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, NOW()) +ON CONFLICT (user_id) DO UPDATE SET + storage_provider = EXCLUDED.storage_provider, + storage_key = EXCLUDED.storage_key, + url = EXCLUDED.url, + content_type = EXCLUDED.content_type, + byte_size = EXCLUDED.byte_size, + sha256 = EXCLUDED.sha256, + updated_at = NOW()`, + userID, + strings.TrimSpace(input.StorageProvider), + strings.TrimSpace(input.StorageKey), + strings.TrimSpace(input.URL), + strings.TrimSpace(input.ContentType), + input.ByteSize, + strings.TrimSpace(input.SHA256), + ) + if err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: strings.TrimSpace(input.StorageProvider), + StorageKey: strings.TrimSpace(input.StorageKey), + URL: strings.TrimSpace(input.URL), + ContentType: strings.TrimSpace(input.ContentType), + ByteSize: input.ByteSize, + SHA256: strings.TrimSpace(input.SHA256), + }, nil +} + +func (r *userRepository) DeleteUserAvatar(ctx context.Context, userID int64) error { + exec, err := r.userProfileIdentitySQL(ctx) + if err != nil { + return err + } + _, err = exec.ExecContext(ctx, `DELETE FROM user_avatars WHERE user_id = $1`, userID) + return err +} + +func (r *userRepository) attachUserAvatar(ctx context.Context, user *service.User) error { + if user == nil { + return nil + } + + avatar, err := r.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + if avatar == nil { + return nil + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 + return nil +} + +func copyMetadata(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func txAwareSQLExecutor(ctx context.Context, fallback sqlExecutor, client *dbent.Client) sqlQueryExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + if exec := sqlExecutorFromEntClient(tx.Client()); exec != nil { + return exec + } + } + if fallback != nil { + return fallback + } + return sqlExecutorFromEntClient(client) +} + +func (r *userRepository) userProfileIdentitySQL(ctx context.Context) (sqlQueryExecutor, error) { + exec := txAwareSQLExecutor(ctx, r.sql, r.client) + if exec == nil { + return nil, fmt.Errorf("sql executor is not configured") + } + return exec, nil +} + +func sqlExecutorFromEntClient(client *dbent.Client) sqlQueryExecutor { + if client == nil { + return nil + } + + clientValue := reflect.ValueOf(client).Elem() + configValue := clientValue.FieldByName("config") + driverValue := configValue.FieldByName("driver") + if !driverValue.IsValid() { + return nil + } + + driver := reflect.NewAt(driverValue.Type(), unsafe.Pointer(driverValue.UnsafeAddr())).Elem().Interface() + exec, ok := driver.(sqlQueryExecutor) + if !ok { + return nil + } + return exec +} diff --git a/backend/internal/repository/user_profile_identity_repo_contract_test.go b/backend/internal/repository/user_profile_identity_repo_contract_test.go new file mode 100644 index 00000000..19022ec1 --- /dev/null +++ b/backend/internal/repository/user_profile_identity_repo_contract_test.go @@ -0,0 +1,428 @@ +//go:build integration + +package repository + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/suite" +) + +type UserProfileIdentityRepoSuite struct { + suite.Suite + ctx context.Context + client *dbent.Client + repo *userRepository +} + +func TestUserProfileIdentityRepoSuite(t *testing.T) { + suite.Run(t, new(UserProfileIdentityRepoSuite)) +} + +func (s *UserProfileIdentityRepoSuite) SetupTest() { + s.ctx = context.Background() + s.client = testEntClient(s.T()) + s.repo = newUserRepositoryWithSQL(s.client, integrationDB) + + _, err := integrationDB.ExecContext(s.ctx, ` +TRUNCATE TABLE + identity_adoption_decisions, + auth_identity_channels, + auth_identities, + pending_auth_sessions, + auth_identity_migration_reports, + user_provider_default_grants, + user_avatars +RESTART IDENTITY`) + s.Require().NoError(err) +} + +func (s *UserProfileIdentityRepoSuite) mustCreateUser(label string) *dbent.User { + s.T().Helper() + + user, err := s.client.User.Create(). + SetEmail(fmt.Sprintf("%s-%d@example.com", label, time.Now().UnixNano())). + SetPasswordHash("test-password-hash"). + SetRole("user"). + SetStatus("active"). + Save(s.ctx) + s.Require().NoError(err) + return user +} + +func (s *UserProfileIdentityRepoSuite) mustCreatePendingAuthSession(key AuthIdentityKey) *dbent.PendingAuthSession { + s.T().Helper() + + session, err := s.client.PendingAuthSession.Create(). + SetSessionToken(fmt.Sprintf("pending-%d", time.Now().UnixNano())). + SetIntent("bind_current_user"). + SetProviderType(key.ProviderType). + SetProviderKey(key.ProviderKey). + SetProviderSubject(key.ProviderSubject). + SetExpiresAt(time.Now().UTC().Add(15 * time.Minute)). + SetUpstreamIdentityClaims(map[string]any{"provider_subject": key.ProviderSubject}). + SetLocalFlowState(map[string]any{"step": "pending"}). + Save(s.ctx) + s.Require().NoError(err) + return session +} + +func (s *UserProfileIdentityRepoSuite) TestCreateAndLookupCanonicalAndChannelIdentity() { + user := s.mustCreateUser("canonical-channel") + + verifiedAt := time.Now().UTC().Truncate(time.Second) + created, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + Channel: "mp", + ChannelAppID: "wx-app", + ChannelSubject: "openid-123", + }, + Issuer: stringPtr("https://issuer.example"), + VerifiedAt: &verifiedAt, + Metadata: map[string]any{"unionid": "union-123"}, + ChannelMetadata: map[string]any{"openid": "openid-123"}, + }) + s.Require().NoError(err) + s.Require().NotNil(created.Identity) + s.Require().NotNil(created.Channel) + + canonical, err := s.repo.GetUserByCanonicalIdentity(s.ctx, created.IdentityRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, canonical.User.ID) + s.Require().Equal(created.Identity.ID, canonical.Identity.ID) + s.Require().Equal("union-123", canonical.Identity.ProviderSubject) + + channel, err := s.repo.GetUserByChannelIdentity(s.ctx, *created.ChannelRef()) + s.Require().NoError(err) + s.Require().Equal(user.ID, channel.User.ID) + s.Require().Equal(created.Identity.ID, channel.Identity.ID) + s.Require().Equal(created.Channel.ID, channel.Channel.ID) +} + +func (s *UserProfileIdentityRepoSuite) TestBindAuthIdentityToUser_IsIdempotentAndRejectsOtherOwners() { + owner := s.mustCreateUser("owner") + other := s.mustCreateUser("other") + + first, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "first"}, + ChannelMetadata: map[string]any{"scope": "read"}, + }) + s.Require().NoError(err) + + second, err := s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: owner.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + Metadata: map[string]any{"username": "second"}, + ChannelMetadata: map[string]any{"scope": "write"}, + }) + s.Require().NoError(err) + s.Require().Equal(first.Identity.ID, second.Identity.ID) + s.Require().Equal(first.Channel.ID, second.Channel.ID) + s.Require().Equal("second", second.Identity.Metadata["username"]) + s.Require().Equal("write", second.Channel.Metadata["scope"]) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityOwnershipConflict) + + _, err = s.repo.BindAuthIdentityToUser(s.ctx, BindAuthIdentityInput{ + UserID: other.ID, + Canonical: AuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-2", + }, + Channel: &AuthIdentityChannelKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + Channel: "oauth", + ChannelAppID: "linuxdo-web", + ChannelSubject: "subject-1", + }, + }) + s.Require().ErrorIs(err, ErrAuthIdentityChannelOwnershipConflict) +} + +func (s *UserProfileIdentityRepoSuite) TestWithUserProfileIdentityTx_RollsBackIdentityAndGrantOnError() { + user := s.mustCreateUser("tx-rollback") + expectedErr := errors.New("rollback") + + err := s.repo.WithUserProfileIdentityTx(s.ctx, func(txCtx context.Context) error { + _, err := s.repo.CreateAuthIdentity(txCtx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }, + }) + s.Require().NoError(err) + + inserted, err := s.repo.RecordProviderGrant(txCtx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "oidc", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + return expectedErr + }) + s.Require().ErrorIs(err, expectedErr) + + _, err = s.repo.GetUserByCanonicalIdentity(s.ctx, AuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-rollback", + }) + s.Require().True(dbent.IsNotFound(err)) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2 AND grant_reason = $3`, + user.ID, + "oidc", + string(ProviderGrantReasonFirstBind), + ).Scan(&count)) + s.Require().Zero(count) +} + +func (s *UserProfileIdentityRepoSuite) TestRecordProviderGrant_IsIdempotentPerReason() { + user := s.mustCreateUser("grant") + + inserted, err := s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonFirstBind, + }) + s.Require().NoError(err) + s.Require().False(inserted) + + inserted, err = s.repo.RecordProviderGrant(s.ctx, ProviderGrantRecordInput{ + UserID: user.ID, + ProviderType: "wechat", + GrantReason: ProviderGrantReasonSignup, + }) + s.Require().NoError(err) + s.Require().True(inserted) + + var count int + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT COUNT(*) +FROM user_provider_default_grants +WHERE user_id = $1 AND provider_type = $2`, + user.ID, + "wechat", + ).Scan(&count)) + s.Require().Equal(2, count) +} + +func (s *UserProfileIdentityRepoSuite) TestUpsertIdentityAdoptionDecision_PersistsAndLinksIdentity() { + user := s.mustCreateUser("adoption") + identity, err := s.repo.CreateAuthIdentity(s.ctx, CreateAuthIdentityInput{ + UserID: user.ID, + Canonical: AuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + s.Require().NoError(err) + + session := s.mustCreatePendingAuthSession(identity.IdentityRef()) + + first, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + s.Require().NoError(err) + s.Require().True(first.AdoptDisplayName) + s.Require().False(first.AdoptAvatar) + s.Require().Nil(first.IdentityID) + + second, err := s.repo.UpsertIdentityAdoptionDecision(s.ctx, IdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.Identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + s.Require().NoError(err) + s.Require().Equal(first.ID, second.ID) + s.Require().NotNil(second.IdentityID) + s.Require().Equal(identity.Identity.ID, *second.IdentityID) + s.Require().True(second.AdoptAvatar) + + loaded, err := s.repo.GetIdentityAdoptionDecisionByPendingAuthSessionID(s.ctx, session.ID) + s.Require().NoError(err) + s.Require().Equal(second.ID, loaded.ID) + s.Require().Equal(identity.Identity.ID, *loaded.IdentityID) +} + +func (s *UserProfileIdentityRepoSuite) TestUserAvatarCRUDAndUserLookup() { + user := s.mustCreateUser("avatar") + + inlineAvatar, err := s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: "data:image/png;base64,QUJD", + ContentType: "image/png", + ByteSize: 3, + SHA256: "902fbdd2b1df0c4f70b4a5d23525e932", + }) + s.Require().NoError(err) + s.Require().Equal("inline", inlineAvatar.StorageProvider) + s.Require().Equal("data:image/png;base64,QUJD", inlineAvatar.URL) + + loadedAvatar, err := s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("image/png", loadedAvatar.ContentType) + s.Require().Equal(3, loadedAvatar.ByteSize) + + _, err = s.repo.UpsertUserAvatar(s.ctx, user.ID, service.UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/avatar.png", + }) + s.Require().NoError(err) + + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().NotNil(loadedAvatar) + s.Require().Equal("remote_url", loadedAvatar.StorageProvider) + s.Require().Equal("https://cdn.example.com/avatar.png", loadedAvatar.URL) + s.Require().Zero(loadedAvatar.ByteSize) + + s.Require().NoError(s.repo.DeleteUserAvatar(s.ctx, user.ID)) + loadedAvatar, err = s.repo.GetUserAvatar(s.ctx, user.ID) + s.Require().NoError(err) + s.Require().Nil(loadedAvatar) +} + +func (s *UserProfileIdentityRepoSuite) TestAuthIdentityMigrationReportHelpers_ListAndSummarize() { + _, err := integrationDB.ExecContext(s.ctx, ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ('wechat_openid_only_requires_remediation', 'u-1', '{"user_id":1}'::jsonb, '2026-04-20T10:00:00Z'), + ('wechat_openid_only_requires_remediation', 'u-2', '{"user_id":2}'::jsonb, '2026-04-20T11:00:00Z'), + ('oidc_synthetic_email_requires_manual_recovery', 'u-3', '{"user_id":3}'::jsonb, '2026-04-20T12:00:00Z')`) + s.Require().NoError(err) + + summary, err := s.repo.SummarizeAuthIdentityMigrationReports(s.ctx) + s.Require().NoError(err) + s.Require().Equal(int64(3), summary.Total) + s.Require().Equal(int64(2), summary.ByType["wechat_openid_only_requires_remediation"]) + s.Require().Equal(int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) + + reports, err := s.repo.ListAuthIdentityMigrationReports(s.ctx, AuthIdentityMigrationReportQuery{ + ReportType: "wechat_openid_only_requires_remediation", + Limit: 10, + }) + s.Require().NoError(err) + s.Require().Len(reports, 2) + s.Require().Equal("u-2", reports[0].ReportKey) + s.Require().Equal(float64(2), reports[0].Details["user_id"]) + + report, err := s.repo.GetAuthIdentityMigrationReport(s.ctx, "oidc_synthetic_email_requires_manual_recovery", "u-3") + s.Require().NoError(err) + s.Require().Equal("u-3", report.ReportKey) + s.Require().Equal(float64(3), report.Details["user_id"]) +} + +func (s *UserProfileIdentityRepoSuite) TestUpdateUserLastLoginAndActiveAt_UsesDedicatedColumns() { + user := s.mustCreateUser("activity") + loginAt := time.Date(2026, 4, 20, 8, 0, 0, 0, time.UTC) + activeAt := loginAt.Add(5 * time.Minute) + + s.Require().NoError(s.repo.UpdateUserLastLoginAt(s.ctx, user.ID, loginAt)) + s.Require().NoError(s.repo.UpdateUserLastActiveAt(s.ctx, user.ID, activeAt)) + + var storedLoginAt sqlNullTime + var storedActiveAt sqlNullTime + s.Require().NoError(integrationDB.QueryRowContext(s.ctx, ` +SELECT last_login_at, last_active_at +FROM users +WHERE id = $1`, + user.ID, + ).Scan(&storedLoginAt, &storedActiveAt)) + s.Require().True(storedLoginAt.Valid) + s.Require().True(storedActiveAt.Valid) + s.Require().True(storedLoginAt.Time.Equal(loginAt)) + s.Require().True(storedActiveAt.Time.Equal(activeAt)) +} + +type sqlNullTime struct { + Time time.Time + Valid bool +} + +func (t *sqlNullTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + t.Time = v + t.Valid = true + return nil + case nil: + t.Time = time.Time{} + t.Valid = false + return nil + default: + return fmt.Errorf("unsupported scan type %T", value) + } +} + +func stringPtr(v string) *string { + return &v +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 913e1c40..0c607ecc 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -64,6 +64,9 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). + SetNillableLastLoginAt(userIn.LastLoginAt). + SetNillableLastActiveAt(userIn.LastActiveAt). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -151,6 +154,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). SetTotalRecharged(userIn.TotalRecharged) + if userIn.SignupSource != "" { + updateOp = updateOp.SetSignupSource(userIn.SignupSource) + } + if userIn.LastLoginAt != nil { + updateOp = updateOp.SetLastLoginAt(*userIn.LastLoginAt) + } + if userIn.LastActiveAt != nil { + updateOp = updateOp.SetLastActiveAt(*userIn.LastActiveAt) + } if userIn.BalanceNotifyThreshold == nil { updateOp = updateOp.ClearBalanceNotifyThreshold() } @@ -300,6 +312,7 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) var field string defaultField := true + nullsLastField := false switch sortBy { case "email": field = dbuser.FieldEmail @@ -322,6 +335,14 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) case "created_at": field = dbuser.FieldCreatedAt defaultField = false + case "last_login_at": + field = dbuser.FieldLastLoginAt + defaultField = false + nullsLastField = true + case "last_active_at": + field = dbuser.FieldLastActiveAt + defaultField = false + nullsLastField = true default: field = dbuser.FieldID } @@ -330,11 +351,23 @@ func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderNullsLast()).ToFunc(), + dbent.Asc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} } if defaultField && field == dbuser.FieldID { return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} } + if nullsLastField { + return []func(*entsql.Selector){ + entsql.OrderByField(field, entsql.OrderDesc(), entsql.OrderNullsLast()).ToFunc(), + dbent.Desc(dbuser.FieldID), + } + } return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} } @@ -558,10 +591,21 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { return } dst.ID = src.ID + dst.SignupSource = src.SignupSource + dst.LastLoginAt = src.LastLoginAt + dst.LastActiveAt = src.LastActiveAt dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } +func userSignupSourceOrDefault(signupSource string) string { + signupSource = strings.TrimSpace(signupSource) + if signupSource == "" { + return "email" + } + return signupSource +} + // marshalExtraEmails serializes notify email entries to JSON for storage. func marshalExtraEmails(entries []service.NotifyEmailEntry) string { return service.MarshalNotifyEmails(entries) diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go index ab84b0e9..8abef45a 100644 --- a/backend/internal/repository/user_repo_sort_integration_test.go +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -36,4 +37,86 @@ func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { s.Require().Equal(first.ID, users[1].ID) } +func (s *UserRepoSuite) TestCreateAndRead_PreservesSignupSourceAndActivityTimestamps() { + lastLoginAt := time.Now().Add(-2 * time.Hour).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-30 * time.Minute).UTC().Truncate(time.Microsecond) + + created := s.mustCreateUser(&service.User{ + Email: "identity-meta@example.com", + SignupSource: "github", + LastLoginAt: &lastLoginAt, + LastActiveAt: &lastActiveAt, + }) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("github", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestUpdate_PersistsSignupSourceAndActivityTimestamps() { + created := s.mustCreateUser(&service.User{Email: "identity-update@example.com"}) + lastLoginAt := time.Now().Add(-90 * time.Minute).UTC().Truncate(time.Microsecond) + lastActiveAt := time.Now().Add(-15 * time.Minute).UTC().Truncate(time.Microsecond) + + created.SignupSource = "oidc" + created.LastLoginAt = &lastLoginAt + created.LastActiveAt = &lastActiveAt + + s.Require().NoError(s.repo.Update(s.ctx, created)) + + got, err := s.repo.GetByID(s.ctx, created.ID) + s.Require().NoError(err) + s.Require().Equal("oidc", got.SignupSource) + s.Require().NotNil(got.LastLoginAt) + s.Require().NotNil(got.LastActiveAt) + s.Require().True(got.LastLoginAt.Equal(lastLoginAt)) + s.Require().True(got.LastActiveAt.Equal(lastActiveAt)) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastLoginAtDesc() { + older := time.Now().Add(-4 * time.Hour).UTC().Truncate(time.Microsecond) + newer := time.Now().Add(-1 * time.Hour).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-login@example.com"}) + s.mustCreateUser(&service.User{Email: "older-login@example.com", LastLoginAt: &older}) + s.mustCreateUser(&service.User{Email: "newer-login@example.com", LastLoginAt: &newer}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_login_at", + SortOrder: "desc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("newer-login@example.com", users[0].Email) + s.Require().Equal("older-login@example.com", users[1].Email) + s.Require().Equal("nil-login@example.com", users[2].Email) +} + +func (s *UserRepoSuite) TestListWithFilters_SortByLastActiveAtAsc() { + earlier := time.Now().Add(-3 * time.Hour).UTC().Truncate(time.Microsecond) + later := time.Now().Add(-45 * time.Minute).UTC().Truncate(time.Microsecond) + + s.mustCreateUser(&service.User{Email: "nil-active@example.com"}) + s.mustCreateUser(&service.User{Email: "later-active@example.com", LastActiveAt: &later}) + s.mustCreateUser(&service.User{Email: "earlier-active@example.com", LastActiveAt: &earlier}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "last_active_at", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 3) + s.Require().Equal("earlier-active@example.com", users[0].Email) + s.Require().Equal("later-active@example.com", users[1].Email) + s.Require().Equal("nil-active@example.com", users[2].Email) +} + func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b686b986..e903898f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -479,7 +479,7 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyOIDCConnectRedirectURL: "", service.SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", service.SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - service.SettingKeyOIDCConnectUsePKCE: "false", + service.SettingKeyOIDCConnectUsePKCE: "true", service.SettingKeyOIDCConnectValidateIDToken: "true", service.SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", service.SettingKeyOIDCConnectClockSkewSeconds: "120", @@ -549,7 +549,7 @@ func TestAPIContracts(t *testing.T) { "oidc_connect_redirect_url": "", "oidc_connect_frontend_redirect_url": "/auth/oidc/callback", "oidc_connect_token_auth_method": "client_secret_post", - "oidc_connect_use_pkce": false, + "oidc_connect_use_pkce": true, "oidc_connect_validate_id_token": true, "oidc_connect_allowed_signing_algs": "RS256,ES256,PS256", "oidc_connect_clock_skew_seconds": 120, diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index c143b030..911a4064 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -64,12 +64,26 @@ func RegisterAuthRoutes( }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) + auth.GET("/oauth/wechat/start", h.Auth.WeChatOAuthStart) + auth.GET("/oauth/wechat/callback", h.Auth.WeChatOAuthCallback) + auth.POST("/oauth/pending/exchange", + rateLimiter.LimitWithOptions("oauth-pending-exchange", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.ExchangePendingOAuthCompletion, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/wechat/complete-registration", + rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteWeChatOAuthRegistration, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 419ddbc3..b802a9c2 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro } func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fbc856cf..323286b0 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error { return s.deleteErr } +func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + panic("unexpected GetUserAvatar call") +} + +func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + panic("unexpected UpsertUserAvatar call") +} + +func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error { + panic("unexpected DeleteUserAvatar call") +} + func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { panic("unexpected List call") } diff --git a/backend/internal/service/auth_pending_identity_service.go b/backend/internal/service/auth_pending_identity_service.go new file mode 100644 index 00000000..b7e86e12 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service.go @@ -0,0 +1,326 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +var ( + ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found") + ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired") + ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used") + ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid") + ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired") + ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used") + ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session") +) + +const ( + defaultPendingAuthTTL = 15 * time.Minute + defaultPendingAuthCompletionTTL = 5 * time.Minute +) + +type PendingAuthIdentityKey struct { + ProviderType string + ProviderKey string + ProviderSubject string +} + +type CreatePendingAuthSessionInput struct { + SessionToken string + Intent string + Identity PendingAuthIdentityKey + TargetUserID *int64 + RedirectTo string + ResolvedEmail string + RegistrationPasswordHash string + BrowserSessionKey string + UpstreamIdentityClaims map[string]any + LocalFlowState map[string]any + ExpiresAt time.Time +} + +type IssuePendingAuthCompletionCodeInput struct { + PendingAuthSessionID int64 + BrowserSessionKey string + TTL time.Duration +} + +type IssuePendingAuthCompletionCodeResult struct { + Code string + ExpiresAt time.Time +} + +type PendingIdentityAdoptionDecisionInput struct { + PendingAuthSessionID int64 + IdentityID *int64 + AdoptDisplayName bool + AdoptAvatar bool +} + +type AuthPendingIdentityService struct { + entClient *dbent.Client +} + +func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService { + return &AuthPendingIdentityService{entClient: entClient} +} + +func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken := strings.TrimSpace(input.SessionToken) + if sessionToken == "" { + var err error + sessionToken, err = randomOpaqueToken(24) + if err != nil { + return nil, err + } + } + + expiresAt := input.ExpiresAt.UTC() + if expiresAt.IsZero() { + expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL) + } + + create := s.entClient.PendingAuthSession.Create(). + SetSessionToken(sessionToken). + SetIntent(strings.TrimSpace(input.Intent)). + SetProviderType(strings.TrimSpace(input.Identity.ProviderType)). + SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)). + SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)). + SetRedirectTo(strings.TrimSpace(input.RedirectTo)). + SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)). + SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)). + SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)). + SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)). + SetLocalFlowState(copyPendingMap(input.LocalFlowState)). + SetExpiresAt(expiresAt) + if input.TargetUserID != nil { + create = create.SetTargetUserID(*input.TargetUserID) + } + return create.Save(ctx) +} + +func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + + code, err := randomOpaqueToken(24) + if err != nil { + return nil, err + } + ttl := input.TTL + if ttl <= 0 { + ttl = defaultPendingAuthCompletionTTL + } + expiresAt := time.Now().UTC().Add(ttl) + + update := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeHash(hashPendingAuthCode(code)). + SetCompletionCodeExpiresAt(expiresAt) + if strings.TrimSpace(input.BrowserSessionKey) != "" { + update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)) + } + if _, err := update.Save(ctx); err != nil { + return nil, err + } + + return &IssuePendingAuthCompletionCodeResult{ + Code: code, + ExpiresAt: expiresAt, + }, nil +} + +func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode)) + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.CompletionCodeHashEQ(codeHash)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthCodeInvalid + } + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed) +} + +func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + + return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed) +} + +func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + session, err := s.getBrowserSession(ctx, sessionToken) + if err != nil { + return nil, err + } + if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil { + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + sessionToken = strings.TrimSpace(sessionToken) + if sessionToken == "" { + return nil, ErrPendingAuthSessionNotFound + } + + session, err := s.entClient.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(sessionToken)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, ErrPendingAuthSessionNotFound + } + return nil, err + } + return session, nil +} + +func (s *AuthPendingIdentityService) consumeSession( + ctx context.Context, + session *dbent.PendingAuthSession, + browserSessionKey string, + expiredErr error, + consumedErr error, +) (*dbent.PendingAuthSession, error) { + if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil { + return nil, err + } + + now := time.Now().UTC() + updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID). + SetConsumedAt(now). + SetCompletionCodeHash(""). + ClearCompletionCodeExpiresAt(). + Save(ctx) + if err != nil { + return nil, err + } + return updated, nil +} + +func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error { + if session == nil { + return ErrPendingAuthSessionNotFound + } + + now := time.Now().UTC() + if session.ConsumedAt != nil { + return consumedErr + } + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + return expiredErr + } + if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) { + return expiredErr + } + if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) { + return ErrPendingAuthBrowserMismatch + } + return nil +} + +func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) { + if s == nil || s.entClient == nil { + return nil, fmt.Errorf("pending auth ent client is not configured") + } + + existing, err := s.entClient.IdentityAdoptionDecision.Query(). + Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if existing == nil { + create := s.entClient.IdentityAdoptionDecision.Create(). + SetPendingAuthSessionID(input.PendingAuthSessionID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar). + SetDecidedAt(time.Now().UTC()) + if input.IdentityID != nil { + create = create.SetIdentityID(*input.IdentityID) + } + return create.Save(ctx) + } + + update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID). + SetAdoptDisplayName(input.AdoptDisplayName). + SetAdoptAvatar(input.AdoptAvatar) + if input.IdentityID != nil { + update = update.SetIdentityID(*input.IdentityID) + } + return update.Save(ctx) +} + +func copyPendingMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func randomOpaqueToken(byteLen int) (string, error) { + if byteLen <= 0 { + byteLen = 16 + } + buf := make([]byte, byteLen) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func hashPendingAuthCode(code string) string { + sum := sha256.Sum256([]byte(code)) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/service/auth_pending_identity_service_test.go b/backend/internal/service/auth_pending_identity_service_test.go new file mode 100644 index 00000000..c69ebfd2 --- /dev/null +++ b/backend/internal/service/auth_pending_identity_service_test.go @@ -0,0 +1,224 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return NewAuthPendingIdentityService(client), client +} + +func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + targetUser, err := client.User.Create(). + SetEmail("pending-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-123", + }, + TargetUserID: &targetUser.ID, + RedirectTo: "/profile", + ResolvedEmail: "user@example.com", + BrowserSessionKey: "browser-1", + UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"}, + LocalFlowState: map[string]any{"step": "email_required"}, + }) + require.NoError(t, err) + require.NotEmpty(t, session.SessionToken) + require.Equal(t, "bind_current_user", session.Intent) + require.Equal(t, "wechat", session.ProviderType) + require.NotNil(t, session.TargetUserID) + require.Equal(t, targetUser.ID, *session.TargetUserID) + require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"]) + require.Equal(t, "email_required", session.LocalFlowState["step"]) +} + +func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo-main", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expected", + UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"}, + LocalFlowState: map[string]any{"step": "pending"}, + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expected", + }) + require.NoError(t, err) + require.NotEmpty(t, issued.Code) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + require.Empty(t, consumed.CompletionCodeHash) + require.Nil(t, consumed.CompletionCodeExpiresAt) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected") + require.ErrorIs(t, err, ErrPendingAuthCodeInvalid) +} + +func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }, + BrowserSessionKey: "browser-expired", + }) + require.NoError(t, err) + + issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{ + PendingAuthSessionID: session.ID, + BrowserSessionKey: "browser-expired", + TTL: time.Second, + }) + require.NoError(t, err) + + _, err = client.PendingAuthSession.UpdateOneID(session.ID). + SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)). + Save(ctx) + require.NoError(t, err) + + _, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired") + require.ErrorIs(t, err, ErrPendingAuthCodeExpired) +} + +func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) { + svc, client := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("adoption@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + identity, err := client.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("wechat"). + SetProviderKey("wechat-open"). + SetProviderSubject("union-adoption"). + SetMetadata(map[string]any{}). + Save(ctx) + require.NoError(t, err) + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "bind_current_user", + Identity: PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: "wechat-open", + ProviderSubject: "union-adoption", + }, + }) + require.NoError(t, err) + + first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + AdoptDisplayName: true, + AdoptAvatar: false, + }) + require.NoError(t, err) + require.True(t, first.AdoptDisplayName) + require.False(t, first.AdoptAvatar) + require.Nil(t, first.IdentityID) + + second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: session.ID, + IdentityID: &identity.ID, + AdoptDisplayName: true, + AdoptAvatar: true, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + require.NotNil(t, second.IdentityID) + require.Equal(t, identity.ID, *second.IdentityID) + require.True(t, second.AdoptAvatar) +} + +func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) { + svc, _ := newAuthPendingIdentityServiceTestClient(t) + ctx := context.Background() + + session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{ + Intent: "login", + Identity: PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "subject-session-token", + }, + BrowserSessionKey: "browser-session", + LocalFlowState: map[string]any{ + "completion_response": map[string]any{ + "access_token": "token", + }, + }, + }) + require.NoError(t, err) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other") + require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch) + + consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.NoError(t, err) + require.NotNil(t, consumed.ConsumedAt) + + _, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session") + require.ErrorIs(t, err, ErrPendingAuthSessionConsumed) +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fd28cd42..962009ce 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -13,6 +13,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -106,6 +107,13 @@ func NewAuthService( } } +func (s *AuthService) EntClient() *dbent.Client { + if s == nil { + return nil + } + return s.entClient +} + // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { return s.RegisterWithVerification(ctx, email, password, "", "", "") @@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } + s.postAuthUserBootstrap(ctx, user, "email", true) s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) @@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } + s.touchUserLogin(ctx, user.ID) // 生成JWT token token, err := s.GenerateToken(user) @@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) token, err := s.GenerateToken(user) if err != nil { @@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) } } else { @@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser + s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) s.assignDefaultSubscriptions(ctx, user.ID) if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { @@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } + s.touchUserLogin(ctx, user.ID) tokenPair, err := s.GenerateTokenPair(ctx, user, "") if err != nil { @@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -// pendingOAuthTokenTTL is the validity period for pending OAuth tokens. -const pendingOAuthTokenTTL = 10 * time.Minute - -// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens. -const pendingOAuthPurpose = "pending_oauth_registration" - -type pendingOAuthClaims struct { - Email string `json:"email"` - Username string `json:"username"` - Purpose string `json:"purpose"` - jwt.RegisteredClaims -} - -// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity -// while waiting for the user to supply an invitation code. -func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) { - now := time.Now() - claims := &pendingOAuthClaims{ - Email: email, - Username: username, - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString([]byte(s.cfg.JWT.Secret)) -} - -// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity. -// Returns ErrInvalidToken when the token is invalid or expired. -func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) { - if len(tokenStr) > maxTokenLength { - return "", "", ErrInvalidToken - } - parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) - token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return []byte(s.cfg.JWT.Secret), nil - }) - if parseErr != nil { - return "", "", ErrInvalidToken - } - claims, ok := token.Claims.(*pendingOAuthClaims) - if !ok || !token.Valid { - return "", "", ErrInvalidToken - } - if claims.Purpose != pendingOAuthPurpose { - return "", "", ErrInvalidToken - } - return claims.Email, claims.Username, nil -} - func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return @@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int } } +func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { + if user == nil || user.ID <= 0 { + return + } + + if strings.TrimSpace(signupSource) == "" { + signupSource = "email" + } + s.updateUserSignupSource(ctx, user.ID, signupSource) + + if signupSource == "email" { + s.ensureEmailAuthIdentity(ctx, user) + } + if touchLogin { + s.touchUserLogin(ctx, user.ID) + } +} + +func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + if strings.TrimSpace(signupSource) == "" { + return + } + if err := s.entClient.User.UpdateOneID(userID). + SetSignupSource(signupSource). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err) + } +} + +func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { + if s == nil || s.entClient == nil || userID <= 0 { + return + } + now := time.Now().UTC() + if err := s.entClient.User.UpdateOneID(userID). + SetLastLoginAt(now). + SetLastActiveAt(now). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err) + } +} + +func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { + if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { + return + } + + email := strings.ToLower(strings.TrimSpace(user.Email)) + if email == "" || isReservedEmail(email) { + return + } + + if err := s.entClient.AuthIdentity.Create(). + SetUserID(user.ID). + SetProviderType("email"). + SetProviderKey("email"). + SetProviderSubject(email). + SetVerifiedAt(time.Now().UTC()). + SetMetadata(map[string]any{ + "source": "auth_service_dual_write", + }). + OnConflictColumns( + authidentity.FieldProviderType, + authidentity.FieldProviderKey, + authidentity.FieldProviderSubject, + ). + DoNothing(). + Exec(ctx); err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err) + } +} + +func inferLegacySignupSource(email string) string { + normalized := strings.ToLower(strings.TrimSpace(email)) + switch { + case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): + return "linuxdo" + case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): + return "oidc" + case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain): + return "wechat" + default: + return "email" + } +} + func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error { if s.settingService == nil { return nil @@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) { func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go new file mode 100644 index 00000000..5bd2b25d --- /dev/null +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -0,0 +1,153 @@ +//go:build unit + +package service_test + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +type authIdentitySettingRepoStub struct { + values map[string]string +} + +func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if v, ok := s.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} + +func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected GetMultiple call") +} + +func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := repository.NewUserRepository(client, db) + cfg := &config.Config{ + JWT: config.JWTConfig{ + Secret: "test-auth-identity-secret", + ExpireHour: 1, + }, + Default: config.DefaultConfig{ + UserBalance: 3.5, + UserConcurrency: 2, + }, + } + settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + }, + }, cfg) + + svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil) + return svc, repo, client +} + +func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) { + svc, _, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + token, user, err := svc.Register(ctx, "user@example.com", "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, user) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.Equal(t, "email", storedUser.SignupSource) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("user@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.NotNil(t, identity.VerifiedAt) +} + +func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + user := &service.User{ + Email: "login@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second) + _, err := client.User.UpdateOneID(user.ID). + SetLastLoginAt(old). + SetLastActiveAt(old). + Save(ctx) + require.NoError(t, err) + + token, gotUser, err := svc.Login(ctx, user.Email, "password") + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotNil(t, gotUser) + + storedUser, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + require.NotNil(t, storedUser.LastLoginAt) + require.NotNil(t, storedUser.LastActiveAt) + require.True(t, storedUser.LastLoginAt.After(old)) + require.True(t, storedUser.LastActiveAt.After(old)) +} diff --git a/backend/internal/service/auth_service_pending_oauth_test.go b/backend/internal/service/auth_service_pending_oauth_test.go deleted file mode 100644 index 0472e06c..00000000 --- a/backend/internal/service/auth_service_pending_oauth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build unit - -package service - -import ( - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/require" -) - -func newAuthServiceForPendingOAuthTest() *AuthService { - cfg := &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret-pending-oauth", - ExpireHour: 1, - }, - } - return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) -} - -// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。 -func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - token, err := svc.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - require.NotEmpty(t, token) - - email, username, err := svc.VerifyPendingOAuthToken(token) - require.NoError(t, err) - require.Equal(t, "user@example.com", email) - require.Equal(t, "alice", username) -} - -// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - // 签发一个普通 access token(JWTClaims,无 Purpose 字段) - accessToken, err := svc.GenerateToken(&User{ - ID: 1, - Email: "user@example.com", - Role: RoleUser, - }) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(accessToken) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "some_other_purpose", - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - now := time.Now() - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: "", // 旧 token 无此字段,反序列化后为零值 - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - - past := time.Now().Add(-1 * time.Hour) - claims := &pendingOAuthClaims{ - Email: "user@example.com", - Username: "alice", - Purpose: pendingOAuthPurpose, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(past), - IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)), - }, - } - tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret)) - require.NoError(t, err) - - _, _, err = svc.VerifyPendingOAuthToken(tokenStr) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) { - other := NewAuthService(nil, nil, nil, nil, &config.Config{ - JWT: config.JWTConfig{Secret: "other-secret"}, - }, nil, nil, nil, nil, nil, nil) - - token, err := other.CreatePendingOAuthToken("user@example.com", "alice") - require.NoError(t, err) - - svc := newAuthServiceForPendingOAuthTest() - _, _, err = svc.VerifyPendingOAuthToken(token) - require.ErrorIs(t, err, ErrInvalidToken) -} - -// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。 -func TestVerifyPendingOAuthToken_TooLong(t *testing.T) { - svc := newAuthServiceForPendingOAuthTest() - giant := make([]byte, maxTokenLength+1) - for i := range giant { - giant[i] = 'a' - } - _, _, err := svc.VerifyPendingOAuthToken(string(giant)) - require.ErrorIs(t, err, ErrInvalidToken) -} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index cb452efb..1dddf77e 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。 const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" +// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" + // Setting keys const ( // 注册设置 @@ -153,6 +156,29 @@ const ( SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + // 第三方认证来源默认授予配置 + SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" + // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 6c09e354..09e60220 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -13,14 +13,30 @@ import ( "sync" "sync/atomic" "time" + + "golang.org/x/sync/singleflight" ) const ( openAIAccountScheduleLayerPreviousResponse = "previous_response_id" openAIAccountScheduleLayerSessionSticky = "session_hash" openAIAccountScheduleLayerLoadBalance = "load_balance" + openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled" +) + +const ( + openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second + openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second ) +type cachedOpenAIAdvancedSchedulerSetting struct { + enabled bool + expiresAt int64 +} + +var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting +var openAIAdvancedSchedulerSettingSF singleflight.Group + type OpenAIAccountScheduleRequest struct { GroupID *int64 SessionHash string @@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler return snapshot } -func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { +func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository { + if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil { + return nil + } + return s.rateLimitService.settingService.settingRepo +} + +func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled + } + } + + result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) { + if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return cached.enabled, nil + } + } + + enabled := false + if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil { + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout) + defer cancel() + + value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey) + if err == nil { + enabled = strings.EqualFold(strings.TrimSpace(value), "true") + } + } + + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: enabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + return enabled, nil + }) + + enabled, _ := result.(bool) + return enabled +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler { if s == nil { return nil } + if !s.isOpenAIAdvancedSchedulerEnabled(ctx) { + return nil + } s.openaiSchedulerOnce.Do(func() { if s.openaiAccountStats == nil { s.openaiAccountStats = newOpenAIAccountRuntimeStats() @@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule return s.openaiScheduler } +func resetOpenAIAdvancedSchedulerSettingCacheForTest() { + openAIAdvancedSchedulerSettingCache = atomic.Value{} + openAIAdvancedSchedulerSettingSF = singleflight.Group{} +} + func (s *OpenAIGatewayService) SelectAccountWithScheduler( ctx context.Context, groupID *int64, @@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( requiredTransport OpenAIUpstreamTransport, ) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { decision := OpenAIAccountScheduleDecision{} - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(ctx) if scheduler == nil { selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) decision.Layer = openAIAccountScheduleLayerLoadBalance @@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( } func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64 } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return } @@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { } func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { - scheduler := s.getOpenAIAccountScheduler() + scheduler := s.getOpenAIAccountScheduler(context.Background()) if scheduler == nil { return OpenAIAccountSchedulerMetricsSnapshot{} } diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 088815ed..a54f2614 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "math" "sync" @@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct { accountsByID map[int64]*Account } +type schedulerTestOpenAIAccountRepo struct { + AccountRepository + accounts []Account +} + +func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, errors.New("account not found") +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { + var result []Account + for _, acc := range r.accounts { + if acc.Platform == platform { + result = append(result, acc) + } + } + return result, nil +} + +func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) { + return r.ListSchedulableByPlatform(ctx, platform) +} + +type schedulerTestConcurrencyCache struct { + ConcurrencyCache + loadBatchErr error + loadMap map[int64]*AccountLoadInfo + acquireResults map[int64]bool + waitCounts map[int64]int + skipDefaultLoad bool +} + +func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if c.acquireResults != nil { + if result, ok := c.acquireResults[accountID]; ok { + return result, nil + } + } + return true, nil +} + +func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + return nil +} + +func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if c.loadBatchErr != nil { + return nil, c.loadBatchErr + } + out := make(map[int64]*AccountLoadInfo, len(accounts)) + if c.skipDefaultLoad && c.loadMap != nil { + for _, acc := range accounts { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + } + } + return out, nil + } + for _, acc := range accounts { + if c.loadMap != nil { + if load, ok := c.loadMap[acc.ID]; ok { + out[acc.ID] = load + continue + } + } + out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} + } + return out, nil +} + +func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + if c.waitCounts != nil { + if count, ok := c.waitCounts[accountID]; ok { + return count, nil + } + } + return 0, nil +} + +type schedulerTestGatewayCache struct { + sessionBindings map[string]int64 + deletedSessions map[string]int +} + +func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + if id, ok := c.sessionBindings[sessionHash]; ok { + return id, nil + } + return 0, errors.New("not found") +} + +func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + if c.sessionBindings == nil { + c.sessionBindings = make(map[string]int64) + } + c.sessionBindings[sessionHash] = accountID + return nil +} + +func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + if c.sessionBindings == nil { + return nil + } + if c.deletedSessions == nil { + c.deletedSessions = make(map[string]int) + } + c.deletedSessions[sessionHash]++ + delete(c.sessionBindings, sessionHash) + return nil +} + +func newSchedulerTestOpenAIWSV2Config() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} + +type openAIAdvancedSchedulerSettingRepoStub struct { + values map[string]string +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s == nil || s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error { + panic("unexpected call to Set") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + panic("unexpected call to GetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected call to SetMultiple") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected call to GetAll") +} + +func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error { + panic("unexpected call to Delete") +} + +func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + repo := &openAIAdvancedSchedulerSettingRepoStub{ + values: map[string]string{}, + } + if enabled != "" { + repo.values[openAIAdvancedSchedulerSettingKey] = enabled + } + return &RateLimitService{ + settingService: NewSettingService(repo, &config.Config{}), + } +} + func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) { if len(s.snapshotAccounts) == 0 { return nil, false, nil @@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6 return &cloned, nil } +func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10106) + accounts := []Account{ + { + ID: 36001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + }, + { + ID: 36002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cache := &schedulerTestGatewayCache{} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour)) + require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_disabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(36002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + ctx := context.Background() + groupID := int64(10107) + accounts := []Account{ + { + ID: 37001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 37002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.LoadBatchEnabled = false + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour)) + require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_enabled_001", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(37001), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) { ctx := context.Background() groupID := int64(10101) @@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, + cache: cache, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), + } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) @@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}} snapshotService := &SchedulerSnapshotService{cache: snapshotCache} - svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService} + svc := &OpenAIGatewayService{ + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, + cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + schedulerSnapshot: snapshotService, + } account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil) require.NoError(t, err) @@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil} dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5} - cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} + cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}} snapshotCache := &openAISnapshotCacheStub{ snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup}, } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}}, cache: cache, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche } snapshotService := &SchedulerSnapshotService{cache: snapshotCache} svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}}, cfg: &config.Config{}, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: snapshotService, } @@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( "openai_apikey_responses_websockets_v2_enabled": true, }, } - cache := &stubGatewayCache{} + cache := &schedulerTestGatewayCache{} cfg := &config.Config{} cfg.Gateway.OpenAIWS.Enabled = true cfg.Gateway.OpenAIWS.OAuthEnabled = true @@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky( cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: cfg, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } store := svc.getOpenAIWSStateStore() @@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_abc": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS Priority: 9, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_sticky_busy": 21001, }, @@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS cfg.Gateway.OpenAIWS.OAuthEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ acquireResults: map[int64]bool{ 21001: false, // sticky 账号已满 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) @@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP "openai_ws_force_http": true, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_force_http": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick }, }, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_ws_only": 2201, }, } - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, @@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, cache: cache, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, - cfg: newOpenAIWSV2TestConfig(), - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, + cfg: newSchedulerTestOpenAIWSV2Config(), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( @@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, @@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { Schedulable: true, Concurrency: 1, } - cache := &stubGatewayCache{ + cache := &schedulerTestGatewayCache{ sessionBindings: map[string]int64{ "openai:session_hash_metrics": account.ID, }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}}, cache: cache, cfg: &config.Config{}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) @@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 - concurrencyCache := stubConcurrencyCache{ + concurrencyCache := schedulerTestConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, @@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA }, } svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: accounts}, - cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts}, + cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), concurrencyService: NewConcurrencyService(concurrencyCache), } @@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { } func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + resetOpenAIAdvancedSchedulerSettingCacheForTest() + svc := &OpenAIGatewayService{} ttft := 120 svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() - require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot) require.Equal(t, 7, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) @@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) - cfg := newOpenAIWSV2TestConfig() + cfg := newSchedulerTestOpenAIWSV2Config() scheduler.service = &OpenAIGatewayService{cfg: cfg} account := &Account{ ID: 8801, diff --git a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go index c5de8203..ddafc6eb 100644 --- a/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go +++ b/backend/internal/service/openai_account_scheduler_ws_snapshot_test.go @@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool svc := &OpenAIGatewayService{ - accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}}, - cache: &stubGatewayCache{}, + accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}}, + cache: &schedulerTestGatewayCache{}, cfg: cfg, + rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"), schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache}, - concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}), } selection, decision, err := svc.SelectAccountWithScheduler( diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 59764b29..34462a3a 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo SettingHelpImageURL, SettingHelpText, SettingCancelRateLimitOn, SettingCancelRateLimitMax, SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode, + SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource, + SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource, } vals, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { return nil, fmt.Errorf("get payment config settings: %w", err) } cfg := s.parsePaymentConfig(vals) + if s.entClient != nil { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, fmt.Errorf("list enabled provider instances: %w", err) + } + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances)) + } else { + cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil) + } // Load Stripe publishable key from the first enabled Stripe provider instance cfg.StripePublishableKey = s.getStripePublishableKey(ctx) return cfg, nil @@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy } if raw := vals[SettingEnabledPaymentTypes]; raw != "" { + types := make([]string, 0, len(strings.Split(raw, ","))) for _, t := range strings.Split(raw, ",") { t = strings.TrimSpace(t) if t != "" { - cfg.EnabledTypes = append(cfg.EnabledTypes, t) + types = append(types, t) } } + cfg.EnabledTypes = NormalizeVisibleMethods(types) } return cfg } // getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance. func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string { + if s.entClient == nil { + return "" + } instances, err := s.entClient.PaymentProviderInstance.Query(). Where( paymentproviderinstance.EnabledEQ(true), @@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int { } return v } + +func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool { + available := make(map[string]bool, 4) + for _, inst := range instances { + switch inst.ProviderKey { + case payment.TypeAlipay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) { + available[VisibleMethodSourceOfficialAlipay] = true + } + case payment.TypeWxpay: + if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) { + available[VisibleMethodSourceOfficialWechat] = true + } + case payment.TypeEasyPay: + for _, supportedType := range splitTypes(inst.SupportedTypes) { + switch NormalizeVisibleMethod(supportedType) { + case payment.TypeAlipay: + available[VisibleMethodSourceEasyPayAlipay] = true + case payment.TypeWxpay: + available[VisibleMethodSourceEasyPayWechat] = true + } + } + } + } + return available +} + +func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string { + shouldExpose := map[string]bool{ + payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available), + payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available), + } + + seen := make(map[string]struct{}, len(base)+2) + out := make([]string, 0, len(base)+2) + appendType := func(paymentType string) { + paymentType = NormalizeVisibleMethod(paymentType) + if paymentType == "" { + return + } + if _, ok := seen[paymentType]; ok { + return + } + seen[paymentType] = struct{}{} + out = append(out, paymentType) + } + + for _, paymentType := range base { + visibleMethod := NormalizeVisibleMethod(paymentType) + switch visibleMethod { + case payment.TypeAlipay, payment.TypeWxpay: + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + default: + appendType(visibleMethod) + } + } + + for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} { + if shouldExpose[visibleMethod] { + appendType(visibleMethod) + } + } + return out +} + +func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool { + enabledKey := visibleMethodEnabledSettingKey(method) + sourceKey := visibleMethodSourceSettingKey(method) + if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" { + return false + } + source := NormalizeVisibleMethodSource(method, vals[sourceKey]) + return source != "" && available[source] +} diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 027bb796..10919058 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -1,9 +1,17 @@ package service import ( + "context" + "database/sql" "testing" + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/internal/payment" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" ) func TestPcParseFloat(t *testing.T) { @@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) { } }) + t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) { + t.Parallel() + vals := map[string]string{ + SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay", + } + cfg := svc.parsePaymentConfig(vals) + if len(cfg.EnabledTypes) != 2 { + t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes)) + } + if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" { + t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes) + } + }) + t.Run("empty enabled types string", func(t *testing.T) { t.Parallel() vals := map[string]string{ @@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) { }) } } + +func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) { + t.Parallel() + + base := []string{"alipay", "wxpay", "stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + } + available := map[string]bool{ + VisibleMethodSourceOfficialAlipay: true, + VisibleMethodSourceOfficialWechat: false, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"alipay", "stripe"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) { + t.Parallel() + + base := []string{"stripe"} + vals := map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay, + } + available := map[string]bool{ + VisibleMethodSourceEasyPayAlipay: true, + } + + got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available) + want := []string{"stripe", "alipay"} + if len(got) != len(want) { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestBuildVisibleMethodSourceAvailability(t *testing.T) { + t.Parallel() + + instances := []*dbent.PaymentProviderInstance{ + {ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"}, + {ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"}, + {ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"}, + } + + got := buildVisibleMethodSourceAvailability(instances) + if !got[VisibleMethodSourceOfficialAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay) + } + if !got[VisibleMethodSourceEasyPayAlipay] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay) + } + if !got[VisibleMethodSourceOfficialWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat) + } + if !got[VisibleMethodSourceEasyPayWechat] { + t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat) + } +} + +func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("EasyPay Alipay"). + SetConfig("{}"). + SetSupportedTypes("alipay"). + SetEnabled(true). + Save(ctx) + if err != nil { + t.Fatalf("create easypay instance: %v", err) + } + + svc := &PaymentConfigService{ + entClient: client, + settingRepo: &paymentConfigSettingRepoStub{ + values: map[string]string{ + SettingEnabledPaymentTypes: "alipay,wxpay,stripe", + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: "easypay", + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: "wxpay", + }, + }, + } + + cfg, err := svc.GetPaymentConfig(ctx) + if err != nil { + t.Fatalf("GetPaymentConfig returned error: %v", err) + } + + want := []string{payment.TypeAlipay, payment.TypeStripe} + if len(cfg.EnabledTypes) != len(want) { + t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes) + } + for i := range want { + if cfg.EnabledTypes[i] != want[i] { + t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes) + } + } +} + +func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared") + if err != nil { + t.Fatalf("open sqlite: %v", err) + } + t.Cleanup(func() { _ = db.Close() }) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatalf("enable foreign keys: %v", err) + } + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +type paymentConfigSettingRepoStub struct { + values map[string]string +} + +func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, nil +} +func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} +func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } diff --git a/backend/internal/service/payment_resume_service.go b/backend/internal/service/payment_resume_service.go new file mode 100644 index 00000000..894a8198 --- /dev/null +++ b/backend/internal/service/payment_resume_service.go @@ -0,0 +1,248 @@ +package service + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + PaymentSourceHostedRedirect = "hosted_redirect" + PaymentSourceWechatInAppResume = "wechat_in_app_resume" + + paymentResumeFallbackSigningKey = "sub2api-payment-resume" + + SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source" + SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source" + SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled" + SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled" + + VisibleMethodSourceOfficialAlipay = "official_alipay" + VisibleMethodSourceEasyPayAlipay = "easypay_alipay" + VisibleMethodSourceOfficialWechat = "official_wxpay" + VisibleMethodSourceEasyPayWechat = "easypay_wxpay" +) + +type ResumeTokenClaims struct { + OrderID int64 `json:"oid"` + UserID int64 `json:"uid,omitempty"` + ProviderInstanceID string `json:"pi,omitempty"` + ProviderKey string `json:"pk,omitempty"` + PaymentType string `json:"pt,omitempty"` + CanonicalReturnURL string `json:"ru,omitempty"` + IssuedAt int64 `json:"iat"` +} + +type PaymentResumeService struct { + signingKey []byte +} + +type visibleMethodLoadBalancer struct { + inner payment.LoadBalancer + configService *PaymentConfigService +} + +func NewPaymentResumeService(signingKey []byte) *PaymentResumeService { + return &PaymentResumeService{signingKey: signingKey} +} + +func NormalizeVisibleMethod(method string) string { + return payment.GetBasePaymentType(strings.TrimSpace(method)) +} + +func NormalizeVisibleMethods(methods []string) []string { + if len(methods) == 0 { + return nil + } + seen := make(map[string]struct{}, len(methods)) + out := make([]string, 0, len(methods)) + for _, method := range methods { + normalized := NormalizeVisibleMethod(method) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + return out +} + +func NormalizePaymentSource(source string) string { + switch strings.TrimSpace(strings.ToLower(source)) { + case "", PaymentSourceHostedRedirect: + return PaymentSourceHostedRedirect + case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume: + return PaymentSourceWechatInAppResume + default: + return strings.TrimSpace(strings.ToLower(source)) + } +} + +func NormalizeVisibleMethodSource(method, source string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official": + return VisibleMethodSourceOfficialAlipay + case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayAlipay + } + case payment.TypeWxpay: + switch strings.TrimSpace(strings.ToLower(source)) { + case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official": + return VisibleMethodSourceOfficialWechat + case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay: + return VisibleMethodSourceEasyPayWechat + } + } + return "" +} + +func VisibleMethodProviderKeyForSource(method, source string) (string, bool) { + switch NormalizeVisibleMethodSource(method, source) { + case VisibleMethodSourceOfficialAlipay: + return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceEasyPayAlipay: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay + case VisibleMethodSourceOfficialWechat: + return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay + case VisibleMethodSourceEasyPayWechat: + return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay + default: + return "", false + } +} + +func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer { + if inner == nil || configService == nil || configService.settingRepo == nil { + return inner + } + return &visibleMethodLoadBalancer{inner: inner, configService: configService} +} + +func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) { + return lb.inner.GetInstanceConfig(ctx, instanceID) +} + +func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) { + visibleMethod := NormalizeVisibleMethod(paymentType) + if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) { + return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount) + } + + enabledKey := visibleMethodEnabledSettingKey(visibleMethod) + sourceKey := visibleMethodSourceSettingKey(visibleMethod) + vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey}) + if err != nil { + return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err) + } + if vals[enabledKey] != "true" { + return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod) + } + + targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey]) + if !ok { + return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod) + } + return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount) +} + +func visibleMethodEnabledSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipayEnabled + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpayEnabled + default: + return "" + } +} + +func visibleMethodSourceSettingKey(method string) string { + switch NormalizeVisibleMethod(method) { + case payment.TypeAlipay: + return SettingPaymentVisibleMethodAlipaySource + case payment.TypeWxpay: + return SettingPaymentVisibleMethodWxpaySource + default: + return "" + } +} + +func CanonicalizeReturnURL(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", nil + } + parsed, err := url.Parse(raw) + if err != nil || !parsed.IsAbs() || parsed.Host == "" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https") + } + parsed.Fragment = "" + if parsed.Path == "" { + parsed.Path = "/" + } + return parsed.String(), nil +} + +func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) { + if claims.OrderID <= 0 { + return "", fmt.Errorf("resume token requires order id") + } + if claims.IssuedAt == 0 { + claims.IssuedAt = time.Now().Unix() + } + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshal resume claims: %w", err) + } + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + return encodedPayload + "." + s.sign(encodedPayload), nil +} + +func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) { + parts := strings.Split(token, ".") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed") + } + if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed") + } + var claims ResumeTokenClaims + if err := json.Unmarshal(payload, &claims); err != nil { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid") + } + if claims.OrderID <= 0 { + return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id") + } + return &claims, nil +} + +func (s *PaymentResumeService) sign(payload string) string { + key := s.signingKey + if len(key) == 0 { + key = []byte(paymentResumeFallbackSigningKey) + } + mac := hmac.New(sha256.New, key) + _, _ = mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} diff --git a/backend/internal/service/payment_resume_service_test.go b/backend/internal/service/payment_resume_service_test.go new file mode 100644 index 00000000..e56b4a88 --- /dev/null +++ b/backend/internal/service/payment_resume_service_test.go @@ -0,0 +1,240 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/payment" +) + +func TestNormalizeVisibleMethods(t *testing.T) { + t.Parallel() + + got := NormalizeVisibleMethods([]string{ + "alipay_direct", + "alipay", + " wxpay_direct ", + "wxpay", + "stripe", + }) + + want := []string{"alipay", "wxpay", "stripe"} + if len(got) != len(want) { + t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got) + } + } +} + +func TestNormalizePaymentSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expect string + }{ + {name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect}, + {name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume}, + {name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizePaymentSource(tt.input); got != tt.expect { + t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect) + } + }) + } +} + +func TestCanonicalizeReturnURL(t *testing.T) { + t.Parallel() + + got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a") + if err != nil { + t.Fatalf("CanonicalizeReturnURL returned error: %v", err) + } + if got != "https://example.com/pay/result?b=2" { + t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2") + } +} + +func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) { + t.Parallel() + + if _, err := CanonicalizeReturnURL("/payment/result"); err == nil { + t.Fatal("CanonicalizeReturnURL should reject relative URLs") + } +} + +func TestPaymentResumeTokenRoundTrip(t *testing.T) { + t.Parallel() + + svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := svc.CreateToken(ResumeTokenClaims{ + OrderID: 42, + UserID: 7, + ProviderInstanceID: "19", + ProviderKey: "easypay", + PaymentType: "wxpay", + CanonicalReturnURL: "https://example.com/payment/result", + IssuedAt: 1234567890, + }) + if err != nil { + t.Fatalf("CreateToken returned error: %v", err) + } + + claims, err := svc.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken returned error: %v", err) + } + if claims.OrderID != 42 || claims.UserID != 7 { + t.Fatalf("claims mismatch: %+v", claims) + } + if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" { + t.Fatalf("claims provider snapshot mismatch: %+v", claims) + } + if claims.CanonicalReturnURL != "https://example.com/payment/result" { + t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL) + } +} + +func TestNormalizeVisibleMethodSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + input string + want string + }{ + {name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay}, + {name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay}, + {name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat}, + {name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat}, + {name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want { + t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want) + } + }) + } +} + +func TestVisibleMethodProviderKeyForSource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + source string + want string + ok bool + }{ + {name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true}, + {name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true}, + {name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true}, + {name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true}, + {name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source) + if got != tt.want || ok != tt.ok { + t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok) + } + }) + } +} + +func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodAlipayEnabled: "true", + SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + _, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5) + if err != nil { + t.Fatalf("SelectInstance returned error: %v", err) + } + if inner.lastProviderKey != payment.TypeAlipay { + t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay) + } +} + +func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) { + t.Parallel() + + inner := &captureLoadBalancer{} + configService := &PaymentConfigService{ + settingRepo: &paymentSettingRepoStub{ + values: map[string]string{ + SettingPaymentVisibleMethodWxpayEnabled: "false", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + }, + }, + } + lb := newVisibleMethodLoadBalancer(inner, configService) + + if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil { + t.Fatal("SelectInstance should reject disabled visible method") + } +} + +type paymentSettingRepoStub struct { + values map[string]string +} + +func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil } +func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + return s.values[key], nil +} +func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil } +func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} +func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil } +func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} +func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil } + +type captureLoadBalancer struct { + lastProviderKey string + lastPaymentType string +} + +func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) { + return map[string]string{}, nil +} + +func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) { + c.lastProviderKey = providerKey + c.lastPaymentType = paymentType + return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil +} diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index 6fc23f97..e897741a 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -65,15 +65,17 @@ func generateRandomString(n int) string { } type CreateOrderRequest struct { - UserID int64 - Amount float64 - PaymentType string - ClientIP string - IsMobile bool - SrcHost string - SrcURL string - OrderType string - PlanID int64 + UserID int64 + Amount float64 + PaymentType string + ClientIP string + IsMobile bool + SrcHost string + SrcURL string + ReturnURL string + PaymentSource string + OrderType string + PlanID int64 } type CreateOrderResponse struct { @@ -88,6 +90,7 @@ type CreateOrderResponse struct { ClientSecret string `json:"client_secret,omitempty"` ExpiresAt time.Time `json:"expires_at"` PaymentMode string `json:"payment_mode,omitempty"` + ResumeToken string `json:"resume_token,omitempty"` } type OrderListParams struct { @@ -165,10 +168,13 @@ type PaymentService struct { configService *PaymentConfigService userRepo UserRepository groupRepo GroupRepository + resumeService *PaymentResumeService } func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService { - return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo} + svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService)) + return svc } // --- Provider Registry --- @@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string { return &s } +func (s *PaymentService) paymentResume() *PaymentResumeService { + if s.resumeService != nil { + return s.resumeService + } + return NewPaymentResumeService(psResumeSigningKey(s.configService)) +} + +func psResumeSigningKey(configService *PaymentConfigService) []byte { + if configService == nil { + return nil + } + return configService.encryptionKey +} + func psSliceContains(sl []string, s string) bool { for _, v := range sl { if v == s { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 7f4a2eb1..de555478 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/url" + "os" "sort" "strconv" "strings" @@ -114,6 +115,66 @@ type SettingService struct { webSearchManagerBuilder WebSearchManagerBuilder } +type ProviderDefaultGrantSettings struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting + GrantOnSignup bool + GrantOnFirstBind bool +} + +type AuthSourceDefaultSettings struct { + Email ProviderDefaultGrantSettings + LinuxDo ProviderDefaultGrantSettings + OIDC ProviderDefaultGrantSettings + WeChat ProviderDefaultGrantSettings + ForceEmailOnThirdPartySignup bool +} + +type authSourceDefaultKeySet struct { + balance string + concurrency string + subscriptions string + grantOnSignup string + grantOnFirstBind string +} + +var ( + emailAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultEmailBalance, + concurrency: SettingKeyAuthSourceDefaultEmailConcurrency, + subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + } + linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultLinuxDoBalance, + concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency, + subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + } + oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultOIDCBalance, + concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency, + subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + } + weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultWeChatBalance, + concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency, + subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + } +) + +const ( + defaultAuthSourceBalance = 0 + defaultAuthSourceConcurrency = 5 +) + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } + weChatEnabled := isWeChatOAuthConfigured() // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + WeChatOAuthEnabled: weChatEnabled, BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", PaymentEnabled: settings[SettingPaymentEnabled] == "true", OIDCOAuthEnabled: oidcEnabled, @@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"` @@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + WeChatOAuthEnabled: settings.WeChatOAuthEnabled, BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, OIDCOAuthEnabled: settings.OIDCOAuthEnabled, @@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } +func isWeChatOAuthConfigured() bool { + openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != "" + mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" && + strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != "" + return openConfigured || mpConfigured +} + // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". func safeRawJSONArray(raw string) json.RawMessage { raw = strings.TrimSpace(raw) @@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS return parseDefaultSubscriptions(value) } +func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) { + keys := []string{ + SettingKeyAuthSourceDefaultEmailBalance, + SettingKeyAuthSourceDefaultEmailConcurrency, + SettingKeyAuthSourceDefaultEmailSubscriptions, + SettingKeyAuthSourceDefaultEmailGrantOnSignup, + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind, + SettingKeyAuthSourceDefaultLinuxDoBalance, + SettingKeyAuthSourceDefaultLinuxDoConcurrency, + SettingKeyAuthSourceDefaultLinuxDoSubscriptions, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup, + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind, + SettingKeyAuthSourceDefaultOIDCBalance, + SettingKeyAuthSourceDefaultOIDCConcurrency, + SettingKeyAuthSourceDefaultOIDCSubscriptions, + SettingKeyAuthSourceDefaultOIDCGrantOnSignup, + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind, + SettingKeyAuthSourceDefaultWeChatBalance, + SettingKeyAuthSourceDefaultWeChatConcurrency, + SettingKeyAuthSourceDefaultWeChatSubscriptions, + SettingKeyAuthSourceDefaultWeChatGrantOnSignup, + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind, + SettingKeyForceEmailOnThirdPartySignup, + } + + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get auth source default settings: %w", err) + } + + return &AuthSourceDefaultSettings{ + Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys), + LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys), + OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys), + WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", + }, nil +} + +func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error { + if settings == nil { + return nil + } + + for _, subscriptions := range [][]DefaultSubscriptionSetting{ + settings.Email.Subscriptions, + settings.LinuxDo.Subscriptions, + settings.OIDC.Subscriptions, + settings.WeChat.Subscriptions, + } { + if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { + return err + } + } + + updates := make(map[string]string, 21) + writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) + writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) + writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) + writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) + updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return fmt.Errorf("update auth source default settings: %w", err) + } + return nil +} + // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken } + result.OIDCConnectUsePKCE = true + result.OIDCConnectValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) } else { @@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: true, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { defaultPageSize := 20 if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { @@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { effective.RedirectURL = strings.TrimSpace(v) } + effective.UsePKCE = true if !effective.Enabled { return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") @@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } @@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { effective.ValidateIDToken = raw == "true" } + effective.UsePKCE = true + effective.ValidateIDToken = true if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { effective.AllowedSigningAlgs = strings.TrimSpace(v) } @@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config. return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") } case "none": - if !effective.UsePKCE { - return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") - } default: return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") } diff --git a/backend/internal/service/setting_service_auth_source_defaults_test.go b/backend/internal/service/setting_service_auth_source_defaults_test.go new file mode 100644 index 00000000..097bf604 --- /dev/null +++ b/backend/internal/service/setting_service_auth_source_defaults_test.go @@ -0,0 +1,136 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type authSourceDefaultsRepoStub struct { + values map[string]string + updates map[string]string +} + +func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + s.updates = make(map[string]string, len(settings)) + for key, value := range settings { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } + return nil +} + +func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) { + repo := &authSourceDefaultsRepoStub{ + values: map[string]string{ + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true", + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetAuthSourceDefaultSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 12.5, got.Email.Balance) + require.Equal(t, 7, got.Email.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions) + require.False(t, got.Email.GrantOnSignup) + require.False(t, got.Email.GrantOnFirstBind) + require.Equal(t, 0.0, got.LinuxDo.Balance) + require.Equal(t, 5, got.LinuxDo.Concurrency) + require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions) + require.True(t, got.LinuxDo.GrantOnSignup) + require.True(t, got.LinuxDo.GrantOnFirstBind) + require.Equal(t, 5, got.OIDC.Concurrency) + require.Equal(t, 5, got.WeChat.Concurrency) + require.True(t, got.ForceEmailOnThirdPartySignup) +} + +func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) { + repo := &authSourceDefaultsRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{ + Email: ProviderDefaultGrantSettings{ + Balance: 1.25, + Concurrency: 3, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}}, + GrantOnSignup: false, + GrantOnFirstBind: true, + }, + LinuxDo: ProviderDefaultGrantSettings{ + Balance: 2, + Concurrency: 4, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}}, + GrantOnSignup: true, + GrantOnFirstBind: false, + }, + OIDC: ProviderDefaultGrantSettings{ + Balance: 3, + Concurrency: 5, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}}, + GrantOnSignup: true, + GrantOnFirstBind: true, + }, + WeChat: ProviderDefaultGrantSettings{ + Balance: 4, + Concurrency: 6, + Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, + GrantOnSignup: false, + GrantOnFirstBind: false, + }, + ForceEmailOnThirdPartySignup: true, + }) + require.NoError(t, err) + require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance]) + require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency]) + require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup]) + require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind]) + require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup]) + + var got []DefaultSubscriptionSetting + require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got)) + require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ab2eb274..e991ebef 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -152,6 +152,7 @@ type PublicSettings struct { CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool + WeChatOAuthEnabled bool BackendModeEnabled bool PaymentEnabled bool OIDCOAuthEnabled bool diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 59f8aa6b..d8b5325c 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -7,19 +7,27 @@ import ( ) type User struct { - ID int64 - Email string - Username string - Notes string - PasswordHash string - Role string - Balance float64 - Concurrency int - Status string - AllowedGroups []int64 - TokenVersion int64 // Incremented on password change to invalidate existing tokens - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + Email string + Username string + Notes string + AvatarURL string + AvatarSource string + AvatarMIME string + AvatarByteSize int + AvatarSHA256 string + PasswordHash string + Role string + Balance float64 + Concurrency int + Status string + AllowedGroups []int64 + TokenVersion int64 // Incremented on password change to invalidate existing tokens + SignupSource string + LastLoginAt *time.Time + LastActiveAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3490e804..2f6d9427 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/sha256" "crypto/subtle" + "encoding/base64" + "encoding/hex" "fmt" "log/slog" + "net/url" "strings" "time" @@ -17,10 +21,14 @@ var ( ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") + ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL") + ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller") + ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image") ) const ( - maxNotifyEmails = 3 // Maximum number of notification emails per user + maxNotifyEmails = 3 // Maximum number of notification emails per user + maxInlineAvatarBytes = 100 * 1024 // User-level rate limiting for notify email verification codes notifyCodeUserRateLimit = 5 @@ -47,6 +55,9 @@ type UserRepository interface { GetFirstAdmin(ctx context.Context) (*User, error) Update(ctx context.Context, user *User) error Delete(ctx context.Context, id int64) error + GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) + UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + DeleteUserAvatar(ctx context.Context, userID int64) error List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) @@ -71,11 +82,30 @@ type UserRepository interface { type UpdateProfileRequest struct { Email *string `json:"email"` Username *string `json:"username"` + AvatarURL *string `json:"avatar_url"` Concurrency *int `json:"concurrency"` BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } +type UserAvatar struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + +type UpsertUserAvatarInput struct { + StorageProvider string + StorageKey string + URL string + ContentType string + ByteSize int + SHA256 string +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } @@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat user.Username = *req.Username } + if req.AvatarURL != nil { + avatarValue := strings.TrimSpace(*req.AvatarURL) + switch { + case avatarValue == "": + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + applyUserAvatar(user, nil) + default: + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + applyUserAvatar(user, avatar) + } + } + if req.Concurrency != nil { user.Concurrency = *req.Concurrency } @@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return user, nil } +func applyUserAvatar(user *User, avatar *UserAvatar) { + if user == nil { + return + } + if avatar == nil { + user.AvatarURL = "" + user.AvatarSource = "" + user.AvatarMIME = "" + user.AvatarByteSize = 0 + user.AvatarSHA256 = "" + return + } + + user.AvatarURL = avatar.URL + user.AvatarSource = avatar.StorageProvider + user.AvatarMIME = avatar.ContentType + user.AvatarByteSize = avatar.ByteSize + user.AvatarSHA256 = avatar.SHA256 +} + +func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.HasPrefix(raw, "data:") { + return normalizeInlineUserAvatarInput(raw) + } + + parsed, err := url.Parse(raw) + if err != nil || parsed == nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if strings.TrimSpace(parsed.Host) == "" { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + return UpsertUserAvatarInput{ + StorageProvider: "remote_url", + URL: raw, + }, nil +} + +func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) { + body := strings.TrimPrefix(raw, "data:") + meta, encoded, ok := strings.Cut(body, ",") + if !ok { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + meta = strings.TrimSpace(meta) + encoded = strings.TrimSpace(encoded) + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + + contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")]) + if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return UpsertUserAvatarInput{}, ErrAvatarNotImage + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return UpsertUserAvatarInput{}, ErrAvatarInvalid + } + if len(decoded) > maxInlineAvatarBytes { + return UpsertUserAvatarInput{}, ErrAvatarTooLarge + } + + sum := sha256.Sum256(decoded) + return UpsertUserAvatarInput{ + StorageProvider: "inline", + URL: raw, + ContentType: contentType, + ByteSize: len(decoded), + SHA256: hex.EncodeToString(sum[:]), + }, nil +} + // ChangePassword 修改密码 // Security: Increments TokenVersion to invalidate all existing JWT tokens func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error { @@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) { if err != nil { return nil, fmt.Errorf("get user: %w", err) } + if err := s.hydrateUserAvatar(ctx, user); err != nil { + return nil, fmt.Errorf("get user avatar: %w", err) + } return user, nil } +func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error { + if s == nil || s.userRepo == nil || user == nil || user.ID == 0 { + return nil + } + + avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID) + if err != nil { + return err + } + applyUserAvatar(user, avatar) + return nil +} + // List 获取用户列表(管理员功能) func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index a998d5f4..7d63bb36 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -4,6 +4,9 @@ package service import ( "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" "errors" "sync" "sync/atomic" @@ -19,14 +22,65 @@ import ( type mockUserRepo struct { updateBalanceErr error updateBalanceFn func(ctx context.Context, id int64, amount float64) error + getByIDUser *User + getByIDErr error + updateFn func(ctx context.Context, user *User) error + updateCalls int + upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) + upsertAvatarArgs []UpsertUserAvatarInput + deleteAvatarFn func(ctx context.Context, userID int64) error + deleteAvatarIDs []int64 + getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error) } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { + if m.getByIDErr != nil { + return nil, m.getByIDErr + } + if m.getByIDUser != nil { + cloned := *m.getByIDUser + return &cloned, nil + } + return &User{}, nil +} func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Update(ctx context.Context, user *User) error { + m.updateCalls++ + if m.updateFn != nil { + return m.updateFn(ctx, user) + } + return nil +} +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) { + if m.getAvatarFn != nil { + return m.getAvatarFn(ctx, userID) + } + return nil, nil +} +func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) { + m.upsertAvatarArgs = append(m.upsertAvatarArgs, input) + if m.upsertAvatarFn != nil { + return m.upsertAvatarFn(ctx, userID, input) + } + return &UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil +} +func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID) + if m.deleteAvatarFn != nil { + return m.deleteAvatarFn(ctx, userID) + } + return nil +} func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) { require.Equal(t, auth, svc.authCacheInvalidator) require.Equal(t, cache, svc.billingCache) } + +func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) { + raw := []byte("small-avatar") + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + expectedSum := sha256.Sum256(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 7, + Email: "avatar@example.com", + Username: "avatar-user", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType) + require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256) + require.Equal(t, dataURL, updated.AvatarURL) + require.Equal(t, "inline", updated.AvatarSource) + require.Equal(t, "image/png", updated.AvatarMIME) + require.Equal(t, len(raw), updated.AvatarByteSize) + require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256) +} + +func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) { + raw := make([]byte, maxInlineAvatarBytes+1) + dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw) + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 8, + Email: "large-avatar@example.com", + Username: "too-large", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + _, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{ + AvatarURL: &dataURL, + }) + require.ErrorIs(t, err, ErrAvatarTooLarge) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, repo.deleteAvatarIDs) + require.Zero(t, repo.updateCalls) +} + +func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) { + remoteURL := "https://cdn.example.com/avatar.png" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 9, + Email: "remote-avatar@example.com", + Username: "remote-avatar", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{ + AvatarURL: &remoteURL, + }) + require.NoError(t, err) + require.Len(t, repo.upsertAvatarArgs, 1) + require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider) + require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL) + require.Equal(t, remoteURL, updated.AvatarURL) + require.Equal(t, "remote_url", updated.AvatarSource) + require.Zero(t, updated.AvatarByteSize) +} + +func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) { + empty := "" + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 10, + Email: "delete-avatar@example.com", + Username: "delete-avatar", + AvatarURL: "https://cdn.example.com/old.png", + AvatarSource: "remote_url", + }, + } + svc := NewUserService(repo, nil, nil, nil) + + updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{ + AvatarURL: &empty, + }) + require.NoError(t, err) + require.Equal(t, []int64{10}, repo.deleteAvatarIDs) + require.Empty(t, repo.upsertAvatarArgs) + require.Empty(t, updated.AvatarURL) + require.Empty(t, updated.AvatarSource) +} + +func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) { + repo := &mockUserRepo{ + getByIDUser: &User{ + ID: 12, + Email: "profile-avatar@example.com", + Username: "profile-avatar", + }, + getAvatarFn: func(context.Context, int64) (*UserAvatar, error) { + return &UserAvatar{ + StorageProvider: "remote_url", + URL: "https://cdn.example.com/profile.png", + }, nil + }, + } + svc := NewUserService(repo, nil, nil, nil) + + user, err := svc.GetProfile(context.Background(), 12) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL) + require.Equal(t, "remote_url", user.AvatarSource) +} diff --git a/backend/migrations/108_auth_identity_foundation_core.sql b/backend/migrations/108_auth_identity_foundation_core.sql new file mode 100644 index 00000000..117e3ca3 --- /dev/null +++ b/backend/migrations/108_auth_identity_foundation_core.sql @@ -0,0 +1,141 @@ +ALTER TABLE users +ADD COLUMN IF NOT EXISTS signup_source VARCHAR(20) NOT NULL DEFAULT 'email', +ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ NULL, +ADD COLUMN IF NOT EXISTS last_active_at TIMESTAMPTZ NULL; + +UPDATE users +SET signup_source = 'email' +WHERE signup_source IS NULL OR signup_source = ''; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'users_signup_source_check' + ) THEN + ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc')); + END IF; +END $$; + +CREATE TABLE IF NOT EXISTS auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + verified_at TIMESTAMPTZ NULL, + issuer TEXT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identities_provider_subject_key + ON auth_identities (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS auth_identities_user_id_idx + ON auth_identities (user_id); + +CREATE INDEX IF NOT EXISTS auth_identities_user_provider_idx + ON auth_identities (user_id, provider_type); + +CREATE TABLE IF NOT EXISTS auth_identity_channels ( + id BIGSERIAL PRIMARY KEY, + identity_id BIGINT NOT NULL REFERENCES auth_identities(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + channel VARCHAR(20) NOT NULL, + channel_app_id TEXT NOT NULL, + channel_subject TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_channels_channel_key + ON auth_identity_channels (provider_type, provider_key, channel, channel_app_id, channel_subject); + +CREATE INDEX IF NOT EXISTS auth_identity_channels_identity_id_idx + ON auth_identity_channels (identity_id); + +CREATE TABLE IF NOT EXISTS pending_auth_sessions ( + id BIGSERIAL PRIMARY KEY, + session_token VARCHAR(255) NOT NULL, + intent VARCHAR(40) NOT NULL, + provider_type VARCHAR(20) NOT NULL, + provider_key TEXT NOT NULL, + provider_subject TEXT NOT NULL, + target_user_id BIGINT NULL REFERENCES users(id) ON DELETE SET NULL, + redirect_to TEXT NOT NULL DEFAULT '', + resolved_email TEXT NOT NULL DEFAULT '', + registration_password_hash TEXT NOT NULL DEFAULT '', + upstream_identity_claims JSONB NOT NULL DEFAULT '{}'::jsonb, + local_flow_state JSONB NOT NULL DEFAULT '{}'::jsonb, + browser_session_key TEXT NOT NULL DEFAULT '', + completion_code_hash TEXT NOT NULL DEFAULT '', + completion_code_expires_at TIMESTAMPTZ NULL, + email_verified_at TIMESTAMPTZ NULL, + password_verified_at TIMESTAMPTZ NULL, + totp_verified_at TIMESTAMPTZ NULL, + expires_at TIMESTAMPTZ NOT NULL, + consumed_at TIMESTAMPTZ NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT pending_auth_sessions_intent_check + CHECK (intent IN ('login', 'bind_current_user', 'adopt_existing_user_by_email')), + CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS pending_auth_sessions_session_token_key + ON pending_auth_sessions (session_token); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_target_user_id_idx + ON pending_auth_sessions (target_user_id); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_expires_at_idx + ON pending_auth_sessions (expires_at); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_provider_idx + ON pending_auth_sessions (provider_type, provider_key, provider_subject); + +CREATE INDEX IF NOT EXISTS pending_auth_sessions_completion_code_idx + ON pending_auth_sessions (completion_code_hash); + +CREATE TABLE IF NOT EXISTS identity_adoption_decisions ( + id BIGSERIAL PRIMARY KEY, + pending_auth_session_id BIGINT NOT NULL REFERENCES pending_auth_sessions(id) ON DELETE CASCADE, + identity_id BIGINT NULL REFERENCES auth_identities(id) ON DELETE SET NULL, + adopt_display_name BOOLEAN NOT NULL DEFAULT FALSE, + adopt_avatar BOOLEAN NOT NULL DEFAULT FALSE, + decided_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS identity_adoption_decisions_pending_auth_session_id_key + ON identity_adoption_decisions (pending_auth_session_id); + +CREATE INDEX IF NOT EXISTS identity_adoption_decisions_identity_id_idx + ON identity_adoption_decisions (identity_id); + +CREATE TABLE IF NOT EXISTS auth_identity_migration_reports ( + id BIGSERIAL PRIMARY KEY, + report_type VARCHAR(40) NOT NULL, + report_key TEXT NOT NULL, + details JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS auth_identity_migration_reports_type_idx + ON auth_identity_migration_reports (report_type); + +CREATE UNIQUE INDEX IF NOT EXISTS auth_identity_migration_reports_type_key + ON auth_identity_migration_reports (report_type, report_key); diff --git a/backend/migrations/109_auth_identity_compat_backfill.sql b/backend/migrations/109_auth_identity_compat_backfill.sql new file mode 100644 index 00000000..ddbbedbc --- /dev/null +++ b/backend/migrations/109_auth_identity_compat_backfill.sql @@ -0,0 +1,125 @@ +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'email', + 'email', + LOWER(BTRIM(u.email)), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'users.email', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND BTRIM(COALESCE(u.email, '')) <> '' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@linuxdo-connect.invalid')) <> '@linuxdo-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@oidc-connect.invalid')) <> '@oidc-connect.invalid' + AND RIGHT(LOWER(BTRIM(u.email)), LENGTH('@wechat-connect.invalid')) <> '@wechat-connect.invalid' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'linuxdo', + 'linuxdo', + SUBSTRING(BTRIM(u.email) FROM '(?i)^linuxdo-(.+)@linuxdo-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^linuxdo-.+@linuxdo-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +INSERT INTO auth_identities ( + user_id, + provider_type, + provider_key, + provider_subject, + verified_at, + metadata +) +SELECT + u.id, + 'wechat', + 'wechat', + SUBSTRING(BTRIM(u.email) FROM '(?i)^wechat-(.+)@wechat-connect\.invalid$'), + COALESCE(u.updated_at, u.created_at, NOW()), + jsonb_build_object( + 'backfill_source', 'synthetic_email', + 'legacy_email', BTRIM(u.email), + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' +ON CONFLICT (provider_type, provider_key, provider_subject) DO NOTHING; + +UPDATE users +SET signup_source = 'linuxdo' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^linuxdo-.+@linuxdo-connect\.invalid$'; + +UPDATE users +SET signup_source = 'wechat' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^wechat-.+@wechat-connect\.invalid$'; + +UPDATE users +SET signup_source = 'oidc' +WHERE deleted_at IS NULL + AND LOWER(BTRIM(COALESCE(email, ''))) ~ '^oidc-.+@oidc-connect\.invalid$'; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'oidc_synthetic_email_requires_manual_recovery', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'cannot recover issuer_plus_sub deterministically from synthetic email alone', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^oidc-.+@oidc-connect\.invalid$' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_openid_only_requires_remediation', + CAST(u.id AS TEXT), + jsonb_build_object( + 'user_id', u.id, + 'email', LOWER(BTRIM(u.email)), + 'reason', 'legacy wechat synthetic identity requires explicit unionid remediation if channel-only data exists', + 'migration', '109_auth_identity_compat_backfill' + ) +FROM users AS u +WHERE u.deleted_at IS NULL + AND LOWER(BTRIM(u.email)) ~ '^wechat-.+@wechat-connect\.invalid$' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities ai + WHERE ai.user_id = u.id + AND ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + ) +ON CONFLICT (report_type, report_key) DO NOTHING; diff --git a/backend/migrations/110_pending_auth_and_provider_default_grants.sql b/backend/migrations/110_pending_auth_and_provider_default_grants.sql new file mode 100644 index 00000000..fbaed62e --- /dev/null +++ b/backend/migrations/110_pending_auth_and_provider_default_grants.sql @@ -0,0 +1,60 @@ +CREATE TABLE IF NOT EXISTS user_provider_default_grants ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider_type VARCHAR(20) NOT NULL, + grant_reason VARCHAR(20) NOT NULL DEFAULT 'first_bind', + granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT user_provider_default_grants_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc')), + CONSTRAINT user_provider_default_grants_reason_check + CHECK (grant_reason IN ('signup', 'first_bind')) +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_provider_default_grants_user_provider_reason_key + ON user_provider_default_grants (user_id, provider_type, grant_reason); + +CREATE INDEX IF NOT EXISTS user_provider_default_grants_user_id_idx + ON user_provider_default_grants (user_id); + +CREATE TABLE IF NOT EXISTS user_avatars ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + storage_provider VARCHAR(20) NOT NULL DEFAULT 'database', + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL DEFAULT '', + content_type VARCHAR(100) NOT NULL DEFAULT '', + byte_size INT NOT NULL DEFAULT 0, + sha256 VARCHAR(64) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS user_avatars_user_id_key + ON user_avatars (user_id); + +INSERT INTO settings (key, value) +VALUES + ('auth_source_default_email_balance', '0'), + ('auth_source_default_email_concurrency', '5'), + ('auth_source_default_email_subscriptions', '[]'), + ('auth_source_default_email_grant_on_signup', 'true'), + ('auth_source_default_email_grant_on_first_bind', 'false'), + ('auth_source_default_linuxdo_balance', '0'), + ('auth_source_default_linuxdo_concurrency', '5'), + ('auth_source_default_linuxdo_subscriptions', '[]'), + ('auth_source_default_linuxdo_grant_on_signup', 'true'), + ('auth_source_default_linuxdo_grant_on_first_bind', 'false'), + ('auth_source_default_oidc_balance', '0'), + ('auth_source_default_oidc_concurrency', '5'), + ('auth_source_default_oidc_subscriptions', '[]'), + ('auth_source_default_oidc_grant_on_signup', 'true'), + ('auth_source_default_oidc_grant_on_first_bind', 'false'), + ('auth_source_default_wechat_balance', '0'), + ('auth_source_default_wechat_concurrency', '5'), + ('auth_source_default_wechat_subscriptions', '[]'), + ('auth_source_default_wechat_grant_on_signup', 'true'), + ('auth_source_default_wechat_grant_on_first_bind', 'false'), + ('force_email_on_third_party_signup', 'false') +ON CONFLICT (key) DO NOTHING; + diff --git a/backend/migrations/111_payment_routing_and_scheduler_flags.sql b/backend/migrations/111_payment_routing_and_scheduler_flags.sql new file mode 100644 index 00000000..f222a8d4 --- /dev/null +++ b/backend/migrations/111_payment_routing_and_scheduler_flags.sql @@ -0,0 +1,8 @@ +INSERT INTO settings (key, value) +VALUES + ('payment_visible_method_alipay_source', ''), + ('payment_visible_method_wxpay_source', ''), + ('payment_visible_method_alipay_enabled', 'false'), + ('payment_visible_method_wxpay_enabled', 'false'), + ('openai_advanced_scheduler_enabled', 'false') +ON CONFLICT (key) DO NOTHING; diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts new file mode 100644 index 00000000..574e1e36 --- /dev/null +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -0,0 +1,60 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const post = vi.fn() + +vi.mock('@/api/client', () => ({ + apiClient: { + post + } +})) + +describe('oauth adoption auth api', () => { + beforeEach(() => { + post.mockReset() + post.mockResolvedValue({ data: {} }) + }) + + it('posts adoption decisions when exchanging pending oauth completion', async () => { + const { exchangePendingOAuthCompletion } = await import('@/api/auth') + + await exchangePendingOAuthCompletion({ + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: false, + adopt_avatar: true + }) + }) + + it('posts linuxdo invitation completion with adoption decisions', async () => { + const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') + + await completeLinuxDoOAuthRegistration('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + + it('posts oidc invitation completion with adoption decisions', async () => { + const { completeOIDCOAuthRegistration } = await import('@/api/auth') + + await completeOIDCOAuthRegistration('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) +}) diff --git a/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts new file mode 100644 index 00000000..8756146e --- /dev/null +++ b/frontend/src/api/__tests__/settings.authSourceDefaults.spec.ts @@ -0,0 +1,118 @@ +import { describe, expect, it } from 'vitest' + +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + type UpdateSettingsRequest, +} from '@/api/admin/settings' + +describe('admin settings auth source defaults helpers', () => { + it('builds auth source defaults state from flat settings fields', () => { + const state = buildAuthSourceDefaultsState({ + auth_source_default_email_balance: 9.5, + auth_source_default_email_concurrency: 3, + auth_source_default_email_subscriptions: [ + { group_id: 1, validity_days: 30 }, + ], + auth_source_default_email_grant_on_signup: false, + auth_source_default_email_grant_on_first_bind: true, + auth_source_default_linuxdo_balance: 6, + auth_source_default_linuxdo_concurrency: 8, + auth_source_default_linuxdo_subscriptions: [ + { group_id: 2, validity_days: 60 }, + ], + auth_source_default_linuxdo_grant_on_signup: true, + auth_source_default_linuxdo_grant_on_first_bind: false, + }) + + expect(state.email).toEqual({ + balance: 9.5, + concurrency: 3, + subscriptions: [{ group_id: 1, validity_days: 30 }], + grant_on_signup: false, + grant_on_first_bind: true, + }) + expect(state.linuxdo).toEqual({ + balance: 6, + concurrency: 8, + subscriptions: [{ group_id: 2, validity_days: 60 }], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.oidc).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + expect(state.wechat).toEqual({ + balance: 0, + concurrency: 5, + subscriptions: [], + grant_on_signup: true, + grant_on_first_bind: false, + }) + }) + + it('appends auth source defaults back onto update payload', () => { + const payload: UpdateSettingsRequest = { + site_name: 'Sub2API', + } + + appendAuthSourceDefaultsToUpdateRequest(payload, { + email: { + balance: 1.25, + concurrency: 2, + subscriptions: [{ group_id: 3, validity_days: 7 }], + grant_on_signup: true, + grant_on_first_bind: false, + }, + linuxdo: { + balance: 0, + concurrency: 6, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: true, + }, + oidc: { + balance: 4, + concurrency: 9, + subscriptions: [{ group_id: 9, validity_days: 90 }], + grant_on_signup: true, + grant_on_first_bind: true, + }, + wechat: { + balance: 2, + concurrency: 5, + subscriptions: [], + grant_on_signup: false, + grant_on_first_bind: false, + }, + }) + + expect(payload).toMatchObject({ + site_name: 'Sub2API', + auth_source_default_email_balance: 1.25, + auth_source_default_email_concurrency: 2, + auth_source_default_email_subscriptions: [{ group_id: 3, validity_days: 7 }], + auth_source_default_email_grant_on_signup: true, + auth_source_default_email_grant_on_first_bind: false, + auth_source_default_linuxdo_balance: 0, + auth_source_default_linuxdo_concurrency: 6, + auth_source_default_linuxdo_subscriptions: [], + auth_source_default_linuxdo_grant_on_signup: false, + auth_source_default_linuxdo_grant_on_first_bind: true, + auth_source_default_oidc_balance: 4, + auth_source_default_oidc_concurrency: 9, + auth_source_default_oidc_subscriptions: [{ group_id: 9, validity_days: 90 }], + auth_source_default_oidc_grant_on_signup: true, + auth_source_default_oidc_grant_on_first_bind: true, + auth_source_default_wechat_balance: 2, + auth_source_default_wechat_concurrency: 5, + auth_source_default_wechat_subscriptions: [], + auth_source_default_wechat_grant_on_signup: false, + auth_source_default_wechat_grant_on_first_bind: false, + }) + }) +}) diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 1e4a3053..8e182c1c 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -11,6 +11,81 @@ export interface DefaultSubscriptionSetting { validity_days: number } +export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface AuthSourceDefaultsValue { + balance: number + concurrency: number + subscriptions: DefaultSubscriptionSetting[] + grant_on_signup: boolean + grant_on_first_bind: boolean +} + +export type AuthSourceDefaultsState = Record + +const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat'] +const AUTH_SOURCE_DEFAULT_BALANCE = 0 +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5 + +export function normalizeDefaultSubscriptionSettings( + subscriptions: DefaultSubscriptionSetting[] | null | undefined +): DefaultSubscriptionSetting[] { + if (!Array.isArray(subscriptions)) return [] + + return subscriptions + .filter((item) => item.group_id > 0 && item.validity_days > 0) + .map((item) => ({ + group_id: Math.floor(item.group_id), + validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) + })) +} + +export function buildAuthSourceDefaultsState( + settings: Partial +): AuthSourceDefaultsState { + const raw = settings as Record + + return AUTH_SOURCE_TYPES.reduce((acc, source) => { + const subscriptions = raw[`auth_source_default_${source}_subscriptions`] + acc[source] = { + balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE), + concurrency: Math.max( + 1, + Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY) + ), + subscriptions: normalizeDefaultSubscriptionSettings( + Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : [] + ), + grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false, + grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + } + return acc + }, {} as AuthSourceDefaultsState) +} + +export function appendAuthSourceDefaultsToUpdateRequest( + payload: UpdateSettingsRequest, + authSourceDefaults: AuthSourceDefaultsState +): UpdateSettingsRequest { + const target = payload as Record + + for (const source of AUTH_SOURCE_TYPES) { + const current = authSourceDefaults[source] + target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0 + target[`auth_source_default_${source}_concurrency`] = Math.max( + 1, + Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY) + ) + target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings( + current.subscriptions + ) + target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup + target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind + } + + return payload +} + /** * System settings interface */ @@ -29,6 +104,27 @@ export interface SystemSettings { default_balance: number default_concurrency: number default_subscriptions: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean // OEM settings site_name: string site_logo: string @@ -137,6 +233,11 @@ export interface SystemSettings { payment_cancel_rate_limit_window: number payment_cancel_rate_limit_unit: string payment_cancel_rate_limit_window_mode: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled: boolean @@ -158,6 +259,27 @@ export interface UpdateSettingsRequest { default_balance?: number default_concurrency?: number default_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_balance?: number + auth_source_default_email_concurrency?: number + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_email_grant_on_signup?: boolean + auth_source_default_email_grant_on_first_bind?: boolean + auth_source_default_linuxdo_balance?: number + auth_source_default_linuxdo_concurrency?: number + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_linuxdo_grant_on_signup?: boolean + auth_source_default_linuxdo_grant_on_first_bind?: boolean + auth_source_default_oidc_balance?: number + auth_source_default_oidc_concurrency?: number + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_oidc_grant_on_signup?: boolean + auth_source_default_oidc_grant_on_first_bind?: boolean + auth_source_default_wechat_balance?: number + auth_source_default_wechat_concurrency?: number + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] + auth_source_default_wechat_grant_on_signup?: boolean + auth_source_default_wechat_grant_on_first_bind?: boolean + force_email_on_third_party_signup?: boolean site_name?: string site_logo?: string site_subtitle?: string @@ -245,6 +367,11 @@ export interface UpdateSettingsRequest { payment_cancel_rate_limit_window?: number payment_cancel_rate_limit_unit?: string payment_cancel_rate_limit_window_mode?: string + payment_visible_method_alipay_source?: string + payment_visible_method_wxpay_source?: string + payment_visible_method_alipay_enabled?: boolean + payment_visible_method_wxpay_enabled?: boolean + openai_advanced_scheduler_enabled?: boolean // Balance & quota notification balance_low_notify_enabled?: boolean balance_low_notify_threshold?: number diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index d7abcd6a..10b6ca58 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -198,6 +198,26 @@ export interface PendingOAuthExchangeResponse { suggested_avatar_url?: string } +export interface OAuthAdoptionDecision { + adoptDisplayName?: boolean + adoptAvatar?: boolean +} + +function serializeOAuthAdoptionDecision( + decision?: OAuthAdoptionDecision +): Record { + const payload: Record = {} + + if (typeof decision?.adoptDisplayName === 'boolean') { + payload.adopt_display_name = decision.adoptDisplayName + } + if (typeof decision?.adoptAvatar === 'boolean') { + payload.adopt_avatar = decision.adoptAvatar + } + + return payload +} + /** * Refresh the access token using the refresh token * @returns New token pair @@ -353,7 +373,8 @@ export async function resetPassword(request: ResetPasswordRequest): Promise { const { data } = await apiClient.post<{ access_token: string @@ -361,7 +382,8 @@ export async function completeLinuxDoOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/linuxdo/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } @@ -372,7 +394,8 @@ export async function completeLinuxDoOAuthRegistration( * @returns Token pair on success */ export async function completeOIDCOAuthRegistration( - invitationCode: string + invitationCode: string, + decision?: OAuthAdoptionDecision ): Promise<{ access_token: string; refresh_token: string; expires_in: number; token_type: string }> { const { data } = await apiClient.post<{ access_token: string @@ -380,13 +403,19 @@ export async function completeOIDCOAuthRegistration( expires_in: number token_type: string }>('/auth/oauth/oidc/complete-registration', { - invitation_code: invitationCode + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) }) return data } -export async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + '/auth/oauth/pending/exchange', + serializeOAuthAdoptionDecision(decision) + ) return data } diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue new file mode 100644 index 00000000..94e20222 --- /dev/null +++ b/frontend/src/components/auth/WechatOAuthSection.vue @@ -0,0 +1,53 @@ + + + diff --git a/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts new file mode 100644 index 00000000..810832a0 --- /dev/null +++ b/frontend/src/components/auth/__tests__/WechatOAuthSection.spec.ts @@ -0,0 +1,74 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import WechatOAuthSection from '@/components/auth/WechatOAuthSection.vue' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/login' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', () => ({ + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.signIn') { + return `Continue with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oauthOrContinue') { + return 'or continue' + } + return key + }, + }), +})) + +describe('WechatOAuthSection', () => { + beforeEach(() => { + routeState.query = { redirect: '/billing?plan=pro' } + locationState.current = { href: 'http://localhost/login' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('starts the open WeChat OAuth flow with the current redirect target', async () => { + const wrapper = mount(WechatOAuthSection) + + expect(wrapper.text()).toContain('WeChat') + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=open&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) + + it('uses mp mode inside the WeChat browser', async () => { + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0 MicroMessenger', + }) + const wrapper = mount(WechatOAuthSection) + + await wrapper.get('button').trigger('click') + + expect(locationState.current.href).toContain( + '/api/v1/auth/oauth/wechat/start?mode=mp&redirect=%2Fbilling%3Fplan%3Dpro' + ) + }) +}) diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 974dee66..8f5a5666 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -141,7 +141,9 @@ const props = defineProps<{ orderType?: string }>() -const emit = defineEmits<{ done: []; success: [] }>() +type PaymentOutcome = 'success' | 'cancelled' | 'expired' + +const emit = defineEmits<{ done: []; success: []; settled: [outcome: PaymentOutcome] }>() const { t } = useI18n() const paymentStore = usePaymentStore() @@ -154,7 +156,7 @@ const cancelling = ref(false) const paidOrder = ref(null) // Terminal outcome: null = still active, 'success' | 'cancelled' | 'expired' -const outcome = ref<'success' | 'cancelled' | 'expired' | null>(null) +const outcome = ref(null) let pollTimer: ReturnType | null = null let countdownTimer: ReturnType | null = null @@ -194,10 +196,19 @@ const countdownDisplay = computed(() => { function reopenPopup() { if (props.payUrl) { - window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + const win = window.open(props.payUrl, 'paymentPopup', POPUP_WINDOW_FEATURES) + if (!win || win.closed) { + window.location.href = props.payUrl + } } } +function setOutcome(next: PaymentOutcome) { + if (outcome.value === next) return + outcome.value = next + emit('settled', next) +} + async function renderQR() { await nextTick() if (!qrCanvas.value || !qrUrl.value) return @@ -214,23 +225,23 @@ async function pollStatus() { if (order.status === 'COMPLETED' || order.status === 'PAID') { cleanup() paidOrder.value = order - outcome.value = 'success' + setOutcome('success') emit('success') } else if (order.status === 'CANCELLED') { cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } else if (order.status === 'EXPIRED' || order.status === 'FAILED') { cleanup() - outcome.value = 'expired' + setOutcome('expired') } } function startCountdown(seconds: number) { remainingSeconds.value = Math.max(0, seconds) - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; return } + if (remainingSeconds.value <= 0) { setOutcome('expired'); return } countdownTimer = setInterval(() => { remainingSeconds.value-- - if (remainingSeconds.value <= 0) { outcome.value = 'expired'; cleanup() } + if (remainingSeconds.value <= 0) { setOutcome('expired'); cleanup() } }, 1000) } @@ -240,7 +251,7 @@ async function handleCancel() { try { await paymentAPI.cancelOrder(props.orderId) cleanup() - outcome.value = 'cancelled' + setOutcome('cancelled') } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } finally { diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts new file mode 100644 index 00000000..f5212f15 --- /dev/null +++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts @@ -0,0 +1,163 @@ +import { describe, expect, it } from 'vitest' +import type { CreateOrderResult, MethodLimit } from '@/types/payment' +import { + decidePaymentLaunch, + getVisibleMethods, + readPaymentRecoverySnapshot, + type PaymentRecoverySnapshot, +} from '@/components/payment/paymentFlow' + +function methodLimit(overrides: Partial = {}): MethodLimit { + return { + daily_limit: 0, + daily_used: 0, + daily_remaining: 0, + single_min: 0, + single_max: 0, + fee_rate: 0, + available: true, + ...overrides, + } +} + +function createOrderResult(overrides: Partial = {}): CreateOrderResult { + return { + order_id: 101, + amount: 88, + pay_amount: 88, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + ...overrides, + } +} + +describe('getVisibleMethods', () => { + it('filters hidden provider methods and normalizes aliases', () => { + const visible = getVisibleMethods({ + alipay_direct: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + stripe: methodLimit({ fee_rate: 3 }), + }) + + expect(visible).toEqual({ + alipay: methodLimit({ single_min: 5 }), + wxpay: methodLimit({ single_max: 100 }), + }) + }) + + it('prefers canonical visible methods over aliases when both exist', () => { + const visible = getVisibleMethods({ + alipay: methodLimit({ single_min: 2 }), + alipay_direct: methodLimit({ single_min: 9 }), + wxpay_direct: methodLimit({ fee_rate: 1.2 }), + }) + + expect(visible.alipay.single_min).toBe(2) + expect(visible.wxpay.fee_rate).toBe(1.2) + }) +}) + +describe('decidePaymentLaunch', () => { + it('uses Stripe popup waiting flow for desktop Alipay client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + resume_token: 'resume-1', + }), { + visibleMethod: 'alipay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('stripe_popup') + expect(decision.paymentState.paymentType).toBe('alipay') + expect(decision.stripeMethod).toBe('alipay') + expect(decision.recovery.resumeToken).toBe('resume-1') + }) + + it('uses Stripe route flow for mobile WeChat client secret', () => { + const decision = decidePaymentLaunch(createOrderResult({ + client_secret: 'cs_test', + }), { + visibleMethod: 'wxpay', + orderType: 'subscription', + isMobile: true, + }) + + expect(decision.kind).toBe('stripe_route') + expect(decision.stripeMethod).toBe('wechat_pay') + expect(decision.paymentState.orderType).toBe('subscription') + }) + + it('keeps hosted redirect metadata for recovery flows', () => { + const decision = decidePaymentLaunch(createOrderResult({ + pay_url: 'https://pay.example.com/session/abc', + payment_mode: 'popup', + resume_token: 'resume-2', + }), { + visibleMethod: 'wxpay', + orderType: 'balance', + isMobile: false, + }) + + expect(decision.kind).toBe('redirect_waiting') + expect(decision.paymentState.payUrl).toBe('https://pay.example.com/session/abc') + expect(decision.recovery.paymentMode).toBe('popup') + expect(decision.recovery.resumeToken).toBe('resume-2') + }) +}) + +describe('readPaymentRecoverySnapshot', () => { + it('restores an unexpired snapshot when the resume token matches', () => { + const snapshot: PaymentRecoverySnapshot = { + orderId: 33, + amount: 18, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/33', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-33', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + } + + const restored = readPaymentRecoverySnapshot(JSON.stringify(snapshot), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'resume-33', + }) + + expect(restored?.orderId).toBe(33) + }) + + it('drops expired or mismatched recovery snapshots', () => { + const expiredSnapshot: PaymentRecoverySnapshot = { + orderId: 55, + amount: 18, + qrCode: '', + expiresAt: '2024-01-01T00:10:00.000Z', + paymentType: 'wxpay', + payUrl: 'https://pay.example.com/session/55', + clientSecret: '', + payAmount: 18, + orderType: 'balance', + paymentMode: 'popup', + resumeToken: 'resume-55', + createdAt: Date.UTC(2024, 0, 1, 0, 0, 0), + } + + expect(readPaymentRecoverySnapshot(JSON.stringify(expiredSnapshot), { + now: Date.UTC(2024, 0, 1, 0, 20, 0), + resumeToken: 'resume-55', + })).toBeNull() + + expect(readPaymentRecoverySnapshot(JSON.stringify({ + ...expiredSnapshot, + expiresAt: '2099-01-01T00:10:00.000Z', + }), { + now: Date.UTC(2099, 0, 1, 0, 1, 0), + resumeToken: 'other-token', + })).toBeNull() + }) +}) diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts new file mode 100644 index 00000000..70225a0c --- /dev/null +++ b/frontend/src/components/payment/paymentFlow.ts @@ -0,0 +1,197 @@ +import type { CreateOrderResult, MethodLimit, OrderType } from '@/types/payment' + +export const PAYMENT_RECOVERY_STORAGE_KEY = 'payment.recovery.current' + +const VISIBLE_METHOD_ALIASES = { + alipay: 'alipay', + alipay_direct: 'alipay', + wxpay: 'wxpay', + wxpay_direct: 'wxpay', +} as const + +export type VisiblePaymentMethod = 'alipay' | 'wxpay' +export type StripeVisibleMethod = 'alipay' | 'wechat_pay' +export type PaymentLaunchKind = + | 'qr_waiting' + | 'redirect_waiting' + | 'stripe_popup' + | 'stripe_route' + | 'unhandled' + +export interface PaymentRecoverySnapshot { + orderId: number + amount: number + qrCode: string + expiresAt: string + paymentType: string + payUrl: string + clientSecret: string + payAmount: number + orderType: OrderType | '' + paymentMode: string + resumeToken: string + createdAt: number +} + +export interface PaymentLaunchContext { + visibleMethod: string + orderType: OrderType + isMobile: boolean + now?: number + stripePopupUrl?: string + stripeRouteUrl?: string +} + +export interface PaymentLaunchDecision { + kind: PaymentLaunchKind + paymentState: PaymentRecoverySnapshot + recovery: PaymentRecoverySnapshot + stripeMethod?: StripeVisibleMethod +} + +type CreateOrderFlowResult = CreateOrderResult & { + resume_token?: string +} + +type StorageWriter = Pick + +export function normalizeVisibleMethod(method: string): VisiblePaymentMethod | '' { + const normalized = VISIBLE_METHOD_ALIASES[method.trim() as keyof typeof VISIBLE_METHOD_ALIASES] + return normalized ?? '' +} + +export function getVisibleMethods(methods: Record): Record { + const visible: Record = {} + + Object.entries(methods).forEach(([type, limit]) => { + const normalized = normalizeVisibleMethod(type) + if (!normalized) return + + const isCanonical = type === normalized + const existing = visible[normalized] + if (!existing || isCanonical) { + visible[normalized] = { ...limit } + } + }) + + return visible +} + +export function decidePaymentLaunch( + result: CreateOrderFlowResult, + context: PaymentLaunchContext, +): PaymentLaunchDecision { + const visibleMethod = normalizeVisibleMethod(context.visibleMethod) || context.visibleMethod + const baseState = createPaymentRecoverySnapshot({ + orderId: result.order_id, + amount: result.amount, + qrCode: result.qr_code || '', + expiresAt: result.expires_at || '', + paymentType: visibleMethod, + payUrl: result.pay_url || '', + clientSecret: result.client_secret || '', + payAmount: result.pay_amount, + orderType: context.orderType, + paymentMode: (result.payment_mode || '').trim(), + resumeToken: result.resume_token || '', + }, context.now) + + if (baseState.clientSecret) { + const stripeMethod: StripeVisibleMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay' + const kind: PaymentLaunchKind = stripeMethod === 'alipay' && !context.isMobile + ? 'stripe_popup' + : 'stripe_route' + const payUrl = kind === 'stripe_popup' + ? context.stripePopupUrl || context.stripeRouteUrl || '' + : context.stripeRouteUrl || context.stripePopupUrl || '' + const paymentState = { ...baseState, payUrl } + return { kind, paymentState, recovery: paymentState, stripeMethod } + } + + if (baseState.qrCode) { + return { kind: 'qr_waiting', paymentState: baseState, recovery: baseState } + } + + if (baseState.payUrl) { + return { kind: 'redirect_waiting', paymentState: baseState, recovery: baseState } + } + + return { kind: 'unhandled', paymentState: baseState, recovery: baseState } +} + +export function createPaymentRecoverySnapshot( + state: Omit, + now = Date.now(), +): PaymentRecoverySnapshot { + return { + ...state, + createdAt: now, + } +} + +export function writePaymentRecoverySnapshot( + storage: StorageWriter, + snapshot: PaymentRecoverySnapshot, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.setItem(key, JSON.stringify(snapshot)) +} + +export function clearPaymentRecoverySnapshot( + storage: Pick, + key = PAYMENT_RECOVERY_STORAGE_KEY, +): void { + storage.removeItem(key) +} + +export function readPaymentRecoverySnapshot( + raw: string | null | undefined, + options: { now?: number; resumeToken?: string } = {}, +): PaymentRecoverySnapshot | null { + if (!raw) return null + + try { + const parsed = JSON.parse(raw) as Partial + if ( + typeof parsed.orderId !== 'number' + || typeof parsed.amount !== 'number' + || typeof parsed.qrCode !== 'string' + || typeof parsed.expiresAt !== 'string' + || typeof parsed.paymentType !== 'string' + || typeof parsed.payUrl !== 'string' + || typeof parsed.clientSecret !== 'string' + || typeof parsed.payAmount !== 'number' + || typeof parsed.paymentMode !== 'string' + || typeof parsed.resumeToken !== 'string' + || typeof parsed.createdAt !== 'number' + ) { + return null + } + + const now = options.now ?? Date.now() + const expiresAt = Date.parse(parsed.expiresAt) + if (Number.isFinite(expiresAt) && expiresAt <= now) { + return null + } + if (options.resumeToken && parsed.resumeToken && parsed.resumeToken !== options.resumeToken) { + return null + } + + return { + orderId: parsed.orderId, + amount: parsed.amount, + qrCode: parsed.qrCode, + expiresAt: parsed.expiresAt, + paymentType: parsed.paymentType, + payUrl: parsed.payUrl, + clientSecret: parsed.clientSecret, + payAmount: parsed.payAmount, + orderType: parsed.orderType === 'subscription' ? 'subscription' : 'balance', + paymentMode: parsed.paymentMode, + resumeToken: parsed.resumeToken, + createdAt: parsed.createdAt, + } + } catch { + return null + } +} diff --git a/frontend/src/router/__tests__/wechat-route.spec.ts b/frontend/src/router/__tests__/wechat-route.spec.ts new file mode 100644 index 00000000..84b20452 --- /dev/null +++ b/frontend/src/router/__tests__/wechat-route.spec.ts @@ -0,0 +1,55 @@ +import { describe, expect, it, vi } from 'vitest' + +const authStore = vi.hoisted(() => ({ + checkAuth: vi.fn(), + isAuthenticated: false, + isAdmin: false, + isSimpleMode: false, +})) + +const appStore = vi.hoisted(() => ({ + siteName: 'Sub2API', + backendModeEnabled: false, + cachedPublicSettings: null as null | Record, +})) + +vi.mock('@/stores/auth', () => ({ + useAuthStore: () => authStore, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => appStore, +})) + +vi.mock('@/stores/adminSettings', () => ({ + useAdminSettingsStore: () => ({ + customMenuItems: [], + }), +})) + +vi.mock('@/composables/useNavigationLoading', () => ({ + useNavigationLoadingState: () => ({ + startNavigation: vi.fn(), + endNavigation: vi.fn(), + isLoading: { value: false }, + }), +})) + +vi.mock('@/composables/useRoutePrefetch', () => ({ + useRoutePrefetch: () => ({ + triggerPrefetch: vi.fn(), + cancelPendingPrefetch: vi.fn(), + resetPrefetchState: vi.fn(), + }), +})) + +describe('router WeChat OAuth route', () => { + it('registers the WeChat callback route as a public route', async () => { + const { default: router } = await import('@/router') + const route = router.getRoutes().find((record) => record.name === 'WeChatOAuthCallback') + + expect(route?.path).toBe('/auth/wechat/callback') + expect(route?.meta.requiresAuth).toBe(false) + expect(route?.meta.title).toBe('WeChat OAuth Callback') + }) +}) diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index ad6e71c4..beaa1da2 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -83,6 +83,15 @@ const routes: RouteRecordRaw[] = [ title: 'LinuxDo OAuth Callback' } }, + { + path: '/auth/wechat/callback', + name: 'WeChatOAuthCallback', + component: () => import('@/views/auth/WechatCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'WeChat OAuth Callback' + } + }, { path: '/auth/oidc/callback', name: 'OIDCOAuthCallback', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 1995383d..1b1af87b 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -336,6 +336,7 @@ export const useAppStore = defineStore('app', () => { custom_menu_items: [], custom_endpoints: [], linuxdo_oauth_enabled: false, + wechat_oauth_enabled: false, oidc_oauth_enabled: false, oidc_oauth_provider_name: 'OIDC', backend_mode_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 89fd777f..529eff55 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -123,6 +123,7 @@ export interface PublicSettings { custom_menu_items: CustomMenuItem[] custom_endpoints: CustomEndpoint[] linuxdo_oauth_enabled: boolean + wechat_oauth_enabled: boolean oidc_oauth_enabled: boolean oidc_oauth_provider_name: string backend_mode_enabled: boolean diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 8bfa0f2b..0d23baa5 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1586,6 +1586,221 @@
+ +
+
+

+ {{ localText('认证来源默认值', 'Auth Source Defaults') }} +

+

+ {{ + localText( + '按注册来源配置新用户默认余额、并发、订阅与授权策略。', + 'Configure per-source default balance, concurrency, subscriptions, and grant rules.' + ) + }} +

+
+
+
+
+ +

+ {{ + localText( + '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。', + 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.' + ) + }} +

+
+ +
+ +
+
+
+
{{ authSource.title }}
+

+ {{ authSource.description }} +

+
+ +
+
+ + +
+
+ + +
+
+ +
+
+
+ +

+ {{ + localText( + '来源首次注册成功后立即发放默认权益。', + 'Grant default entitlements immediately after signup.' + ) + }} +

+
+ +
+ +
+
+ +

+ {{ + localText( + '来源首次绑定到现有账号时发放默认权益。', + 'Grant default entitlements when the source is first bound to an existing user.' + ) + }} +

+
+ +
+
+ +
+
+
+ +

+ {{ + localText( + '仅对当前认证来源生效,未配置时不追加来源专属订阅。', + 'Applies only to this auth source. Leave empty to skip source-specific subscriptions.' + ) + }} +

+
+ +
+ +
+ {{ + localText( + '当前来源未配置专属默认订阅。', + 'No source-specific default subscriptions configured.' + ) + }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+
+
+
+
@@ -1643,19 +1858,38 @@

-
-
-
@@ -2450,6 +2684,59 @@

+
+
+
+
+ +

+ {{ + localText( + '控制前台结算页是否展示该方式,以及展示时使用的来源键。', + 'Controls whether checkout shows this method and which source key it exposes.' + ) + }} +

+
+ +
+ +
+ + +

+ {{ + localText( + '留空表示由后端使用默认来源;可填 easypay、alipay、wxpay 等来源标识。', + 'Leave blank to let the backend decide. Typical values are easypay, alipay, or wxpay.' + ) + }} +

+
+
+
@@ -2827,7 +3114,14 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { adminAPI } from '@/api' +import { + appendAuthSourceDefaultsToUpdateRequest, + buildAuthSourceDefaultsState, + normalizeDefaultSubscriptionSettings, +} from '@/api/admin/settings' import type { + AuthSourceDefaultsState, + AuthSourceType, SystemSettings, UpdateSettingsRequest, DefaultSubscriptionSetting, @@ -2864,6 +3158,10 @@ const { t, locale } = useI18n() const appStore = useAppStore() const adminSettingsStore = useAdminSettingsStore() +function localText(zh: string, en: string): string { + return locale.value.startsWith('zh') ? zh : en +} + type SettingsTab = 'general' | 'security' | 'users' | 'gateway' | 'payment' | 'email' | 'backup' const activeTab = ref('general') const settingsTabs = [ @@ -2960,6 +3258,12 @@ type SettingsForm = SystemSettings & { turnstile_secret_key: string linuxdo_connect_client_secret: string oidc_connect_client_secret: string + force_email_on_third_party_signup: boolean + payment_visible_method_alipay_source: string + payment_visible_method_wxpay_source: string + payment_visible_method_alipay_enabled: boolean + payment_visible_method_wxpay_enabled: boolean + openai_advanced_scheduler_enabled: boolean } const form = reactive({ @@ -2974,6 +3278,7 @@ const form = reactive({ default_balance: 0, default_concurrency: 1, default_subscriptions: [], + force_email_on_third_party_signup: false, site_name: 'Sub2API', site_logo: '', site_subtitle: 'Subscription to API Conversion Platform', @@ -2983,7 +3288,7 @@ const form = reactive({ home_content: '', backend_mode_enabled: false, hide_ccs_import_button: false, - payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', + payment_enabled: false, payment_min_amount: 1, payment_max_amount: 10000, payment_daily_limit: 50000, payment_max_pending_orders: 3, payment_order_timeout_minutes: 30, payment_balance_disabled: false, payment_balance_recharge_multiplier: 1, payment_recharge_fee_rate: 0, payment_enabled_types: [], payment_help_image_url: '', payment_help_text: '', payment_product_name_prefix: '', payment_product_name_suffix: '', payment_load_balance_strategy: 'round-robin', payment_cancel_rate_limit_enabled: false, payment_cancel_rate_limit_max: 10, payment_cancel_rate_limit_window: 1, payment_cancel_rate_limit_unit: 'day', payment_cancel_rate_limit_window_mode: 'rolling', payment_visible_method_alipay_source: '', payment_visible_method_wxpay_source: '', payment_visible_method_alipay_enabled: false, payment_visible_method_wxpay_enabled: false, table_default_page_size: tablePageSizeDefault, table_page_size_options: [10, 20, 50, 100], custom_menu_items: [] as Array<{id: string; label: string; icon_svg: string; url: string; visibility: 'user' | 'admin'; sort_order: number}>, @@ -3051,6 +3356,7 @@ const form = reactive({ max_claude_code_version: '', // 分组隔离 allow_ungrouped_key_scheduling: false, + openai_advanced_scheduler_enabled: false, // Gateway forwarding behavior enable_fingerprint_unification: true, enable_metadata_passthrough: false, @@ -3063,6 +3369,74 @@ const form = reactive({ account_quota_notify_emails: [] as NotifyEmailEntry[] }) +const authSourceDefaults = reactive(buildAuthSourceDefaultsState({})) + +const authSourceDefaultsMeta = computed(() => [ + { + source: 'email' as AuthSourceType, + title: localText('邮箱注册', 'Email signup'), + description: localText('适用于邮箱密码注册的新用户默认配额。', 'Default quota grants for email-password signups.') + }, + { + source: 'linuxdo' as AuthSourceType, + title: localText('Linux DO 登录', 'Linux DO signup'), + description: localText('适用于 Linux DO 第三方注册的新用户默认配额。', 'Default quota grants for Linux DO signups.') + }, + { + source: 'oidc' as AuthSourceType, + title: localText('OIDC 登录', 'OIDC signup'), + description: localText('适用于 OIDC 第三方注册的新用户默认配额。', 'Default quota grants for OIDC signups.') + }, + { + source: 'wechat' as AuthSourceType, + title: localText('微信登录', 'WeChat signup'), + description: localText('适用于微信第三方注册的新用户默认配额。', 'Default quota grants for WeChat signups.') + }, +]) + +const paymentVisibleMethodCards = computed(() => [ + { + key: 'alipay' as const, + title: t('payment.methods.alipay'), + enabledField: 'payment_visible_method_alipay_enabled' as const, + sourceField: 'payment_visible_method_alipay_source' as const, + }, + { + key: 'wxpay' as const, + title: t('payment.methods.wxpay'), + enabledField: 'payment_visible_method_wxpay_enabled' as const, + sourceField: 'payment_visible_method_wxpay_source' as const, + }, +]) + +function getPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay'): boolean { + return method === 'alipay' + ? form.payment_visible_method_alipay_enabled + : form.payment_visible_method_wxpay_enabled +} + +function setPaymentVisibleMethodEnabled(method: 'alipay' | 'wxpay', enabled: boolean) { + if (method === 'alipay') { + form.payment_visible_method_alipay_enabled = enabled + return + } + form.payment_visible_method_wxpay_enabled = enabled +} + +function getPaymentVisibleMethodSource(method: 'alipay' | 'wxpay'): string { + return method === 'alipay' + ? form.payment_visible_method_alipay_source + : form.payment_visible_method_wxpay_source +} + +function setPaymentVisibleMethodSource(method: 'alipay' | 'wxpay', source: string) { + if (method === 'alipay') { + form.payment_visible_method_alipay_source = source + return + } + form.payment_visible_method_wxpay_source = source +} + // Proxies for web search emulation ProxySelector const webSearchProxies = ref([]) @@ -3428,15 +3802,9 @@ async function loadSettings() { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(settings)) form.backend_mode_enabled = settings.backend_mode_enabled - form.default_subscriptions = Array.isArray(settings.default_subscriptions) - ? settings.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item) => ({ - group_id: item.group_id, - validity_days: item.validity_days - })) - : [] + form.default_subscriptions = normalizeDefaultSubscriptionSettings(settings.default_subscriptions) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( settings.registration_email_suffix_whitelist ) @@ -3471,10 +3839,18 @@ async function loadSubscriptionGroups() { } } +function findNextAvailableSubscriptionGroup( + existingGroupIDs: number[] +): AdminGroup | undefined { + const existing = new Set(existingGroupIDs) + return subscriptionGroups.value.find((group) => !existing.has(group.id)) +} + function addDefaultSubscription() { if (subscriptionGroups.value.length === 0) return - const existing = new Set(form.default_subscriptions.map((item) => item.group_id)) - const candidate = subscriptionGroups.value.find((group) => !existing.has(group.id)) + const candidate = findNextAvailableSubscriptionGroup( + form.default_subscriptions.map((item) => item.group_id) + ) if (!candidate) return form.default_subscriptions.push({ group_id: candidate.id, @@ -3486,6 +3862,36 @@ function removeDefaultSubscription(index: number) { form.default_subscriptions.splice(index, 1) } +function addAuthSourceDefaultSubscription(source: AuthSourceType) { + if (subscriptionGroups.value.length === 0) return + const candidate = findNextAvailableSubscriptionGroup( + authSourceDefaults[source].subscriptions.map((item) => item.group_id) + ) + if (!candidate) return + authSourceDefaults[source].subscriptions.push({ + group_id: candidate.id, + validity_days: 30 + }) +} + +function removeAuthSourceDefaultSubscription(source: AuthSourceType, index: number) { + authSourceDefaults[source].subscriptions.splice(index, 1) +} + +function findDuplicateDefaultSubscription( + subscriptions: DefaultSubscriptionSetting[] +): DefaultSubscriptionSetting | undefined { + const seenGroupIDs = new Set() + + return subscriptions.find((item) => { + if (seenGroupIDs.has(item.group_id)) { + return true + } + seenGroupIDs.add(item.group_id) + return false + }) +} + async function saveSettings() { saving.value = true try { @@ -3520,21 +3926,12 @@ async function saveSettings() { form.table_default_page_size = normalizedTableDefaultPageSize form.table_page_size_options = normalizedTablePageSizeOptions - const normalizedDefaultSubscriptions = form.default_subscriptions - .filter((item) => item.group_id > 0 && item.validity_days > 0) - .map((item: DefaultSubscriptionSetting) => ({ - group_id: item.group_id, - validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) - })) - - const seenGroupIDs = new Set() - const duplicateDefaultSubscription = normalizedDefaultSubscriptions.find((item) => { - if (seenGroupIDs.has(item.group_id)) { - return true - } - seenGroupIDs.add(item.group_id) - return false - }) + const normalizedDefaultSubscriptions = normalizeDefaultSubscriptionSettings( + form.default_subscriptions + ) + const duplicateDefaultSubscription = findDuplicateDefaultSubscription( + normalizedDefaultSubscriptions + ) if (duplicateDefaultSubscription) { appStore.showError( t('admin.settings.defaults.defaultSubscriptionsDuplicate', { @@ -3544,6 +3941,23 @@ async function saveSettings() { return } + for (const authSource of authSourceDefaultsMeta.value) { + authSourceDefaults[authSource.source].subscriptions = normalizeDefaultSubscriptionSettings( + authSourceDefaults[authSource.source].subscriptions + ) + const duplicate = findDuplicateDefaultSubscription( + authSourceDefaults[authSource.source].subscriptions + ) + if (duplicate) { + appStore.showError( + `${authSource.title}: ${t('admin.settings.defaults.defaultSubscriptionsDuplicate', { + groupId: duplicate.group_id + })}` + ) + return + } + } + // Validate URL fields — novalidate disables browser-native checks, so we validate here const isValidHttpUrl = (url: string): boolean => { if (!url) return true @@ -3571,6 +3985,7 @@ async function saveSettings() { default_balance: form.default_balance, default_concurrency: form.default_concurrency, default_subscriptions: normalizedDefaultSubscriptions, + force_email_on_third_party_signup: form.force_email_on_third_party_signup, site_name: form.site_name, site_logo: form.site_logo, site_subtitle: form.site_subtitle, @@ -3655,6 +4070,11 @@ async function saveSettings() { payment_cancel_rate_limit_window: Number(form.payment_cancel_rate_limit_window) || 1, payment_cancel_rate_limit_unit: form.payment_cancel_rate_limit_unit, payment_cancel_rate_limit_window_mode: form.payment_cancel_rate_limit_window_mode, + payment_visible_method_alipay_source: form.payment_visible_method_alipay_source, + payment_visible_method_wxpay_source: form.payment_visible_method_wxpay_source, + payment_visible_method_alipay_enabled: form.payment_visible_method_alipay_enabled, + payment_visible_method_wxpay_enabled: form.payment_visible_method_wxpay_enabled, + openai_advanced_scheduler_enabled: form.openai_advanced_scheduler_enabled, // Balance & quota notification balance_low_notify_enabled: form.balance_low_notify_enabled, balance_low_notify_threshold: Number(form.balance_low_notify_threshold) || 0, @@ -3663,12 +4083,15 @@ async function saveSettings() { account_quota_notify_emails: (form.account_quota_notify_emails || []).filter((e) => e.email.trim() !== ''), } + appendAuthSourceDefaultsToUpdateRequest(payload, authSourceDefaults) + const updated = await adminAPI.settings.updateSettings(payload) for (const [key, value] of Object.entries(updated)) { if (value !== null && value !== undefined) { (form as Record)[key] = value } } + Object.assign(authSourceDefaults, buildAuthSourceDefaultsState(updated)) registrationEmailSuffixWhitelistTags.value = normalizeRegistrationEmailSuffixDomains( updated.registration_email_suffix_whitelist ) diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index af48959b..0a125def 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -11,32 +11,94 @@
-
-

- {{ t('auth.linuxdo.invitationRequired') }} -

-
- +
+
+
+
+

+ Use LinuxDo profile details +

+

+ Choose whether to apply the nickname or avatar from LinuxDo to this account. +

+
+ + + + +
- -

- {{ invitationError }} + + + +

@@ -71,7 +133,12 @@ import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' -import { completeLinuxDoOAuthRegistration } from '@/api/auth' +import { + completeLinuxDoOAuthRegistration, + exchangePendingOAuthCompletion, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -85,11 +152,16 @@ const errorMessage = ref('') // Invitation code flow state const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -106,6 +178,54 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.linuxdo.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -113,8 +233,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeLinuxDoOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -134,63 +254,65 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { const params = parseFragmentParams() - - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.linuxdo.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.linuxdo.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - // Store refresh token and expires_at (convert to timestamp) if provided - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } @@ -209,4 +331,3 @@ onMounted(async () => { transform: translateY(-8px); } - diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 70b64e3f..fa4ac34c 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const backendModeEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') @@ -267,6 +274,7 @@ onMounted(async () => { turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled backendModeEnabled.value = settings.backend_mode_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index a6cb6c12..55f8af6e 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -15,36 +15,99 @@
-
-

- {{ t('auth.oidc.invitationRequired', { providerName }) }} -

-
- +
+
+
+
+

+ Use {{ providerName }} profile details +

+

+ Choose whether to apply the nickname or avatar from {{ providerName }} to this + account. +

+
+ + + + +
- -

- {{ invitationError }} + + + +

@@ -81,7 +144,10 @@ import Icon from '@/components/icons/Icon.vue' import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, - getPublicSettings + exchangePendingOAuthCompletion, + getPublicSettings, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse } from '@/api/auth' const route = useRoute() @@ -95,12 +161,17 @@ const isProcessing = ref(true) const errorMessage = ref('') const needsInvitation = ref(false) -const pendingOAuthToken = ref('') const invitationCode = ref('') const isSubmitting = ref(false) const invitationError = ref('') const redirectTo = ref('/dashboard') const providerName = ref('OIDC') +const adoptionRequired = ref(false) +const suggestedDisplayName = ref('') +const suggestedAvatarUrl = ref('') +const adoptDisplayName = ref(true) +const adoptAvatar = ref(true) +const needsAdoptionConfirmation = ref(false) function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -129,6 +200,54 @@ async function loadProviderName() { } } +function currentAdoptionDecision(): OAuthAdoptionDecision { + return { + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value + } +} + +function applyAdoptionSuggestionState(completion: { + adoption_required?: boolean + suggested_display_name?: string + suggested_avatar_url?: string +}) { + adoptionRequired.value = completion.adoption_required === true + suggestedDisplayName.value = completion.suggested_display_name || '' + suggestedAvatarUrl.value = completion.suggested_avatar_url || '' + + if (!suggestedDisplayName.value) { + adoptDisplayName.value = false + } + if (!suggestedAvatarUrl.value) { + adoptAvatar.value = false + } +} + +function hasSuggestedProfile(completion: { + suggested_display_name?: string + suggested_avatar_url?: string +}): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + +async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { + if (!completion.access_token) { + throw new Error(t('auth.oidc.callbackMissingToken')) + } + + if (completion.refresh_token) { + localStorage.setItem('refresh_token', completion.refresh_token) + } + if (completion.expires_in) { + localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + } + + await authStore.setToken(completion.access_token) + appStore.showSuccess(t('auth.loginSuccess')) + await router.replace(redirect) +} + async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return @@ -136,8 +255,8 @@ async function handleSubmitInvitation() { isSubmitting.value = true try { const tokenData = await completeOIDCOAuthRegistration( - pendingOAuthToken.value, - invitationCode.value.trim() + invitationCode.value.trim(), + currentAdoptionDecision() ) if (tokenData.refresh_token) { localStorage.setItem('refresh_token', tokenData.refresh_token) @@ -157,63 +276,67 @@ async function handleSubmitInvitation() { } } +async function handleContinueLogin() { + isSubmitting.value = true + try { + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeLogin(completion, redirectTo.value) + } catch (e: unknown) { + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') + appStore.showError(errorMessage.value) + needsAdoptionConfirmation.value = false + } finally { + isSubmitting.value = false + } +} + onMounted(async () => { void loadProviderName() const params = parseFragmentParams() - const token = params.get('access_token') || '' - const refreshToken = params.get('refresh_token') || '' - const expiresInStr = params.get('expires_in') || '' - const redirect = sanitizeRedirectPath( - params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' - ) const error = params.get('error') const errorDesc = params.get('error_description') || params.get('error_message') || '' if (error) { - if (error === 'invitation_required') { - pendingOAuthToken.value = params.get('pending_oauth_token') || '' - redirectTo.value = sanitizeRedirectPath(params.get('redirect')) - if (!pendingOAuthToken.value) { - errorMessage.value = t('auth.oidc.invalidPendingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - needsInvitation.value = true - isProcessing.value = false - return - } errorMessage.value = errorDesc || error appStore.showError(errorMessage.value) isProcessing.value = false return } - if (!token) { - errorMessage.value = t('auth.oidc.callbackMissingToken') - appStore.showError(errorMessage.value) - isProcessing.value = false - return - } - try { - if (refreshToken) { - localStorage.setItem('refresh_token', refreshToken) + const completion = await exchangePendingOAuthCompletion() + const redirect = sanitizeRedirectPath( + completion.redirect || (route.query.redirect as string | undefined) || '/dashboard' + ) + applyAdoptionSuggestionState(completion) + redirectTo.value = redirect + + if (completion.error === 'invitation_required') { + needsInvitation.value = true + isProcessing.value = false + return } - if (expiresInStr) { - const expiresIn = parseInt(expiresInStr, 10) - if (!isNaN(expiresIn)) { - localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) - } + + if (adoptionRequired.value && hasSuggestedProfile(completion)) { + needsAdoptionConfirmation.value = true + isProcessing.value = false + return } - await authStore.setToken(token) - appStore.showSuccess(t('auth.loginSuccess')) - await router.replace(redirect) + await finalizeLogin(completion, redirect) } catch (e: unknown) { - const err = e as { message?: string; response?: { data?: { detail?: string } } } - errorMessage.value = err.response?.data?.detail || err.message || t('auth.loginFailed') + const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } + errorMessage.value = + err.response?.data?.detail || + err.response?.data?.message || + err.message || + t('auth.loginFailed') appStore.showError(errorMessage.value) isProcessing.value = false } diff --git a/frontend/src/views/auth/RegisterView.vue b/frontend/src/views/auth/RegisterView.vue index bc8b8dce..378f9d8a 100644 --- a/frontend/src/views/auth/RegisterView.vue +++ b/frontend/src/views/auth/RegisterView.vue @@ -11,12 +11,17 @@

-
+
+ (false) const turnstileSiteKey = ref('') const siteName = ref('Sub2API') const linuxdoOAuthEnabled = ref(false) +const wechatOAuthEnabled = ref(false) const oidcOAuthEnabled = ref(false) const oidcOAuthProviderName = ref('OIDC') const registrationEmailSuffixWhitelist = ref([]) @@ -397,6 +404,7 @@ onMounted(async () => { turnstileSiteKey.value = settings.turnstile_site_key || '' siteName.value = settings.site_name || 'Sub2API' linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled + wechatOAuthEnabled.value = settings.wechat_oauth_enabled oidcOAuthEnabled.value = settings.oidc_oauth_enabled oidcOAuthProviderName.value = settings.oidc_oauth_provider_name || 'OIDC' registrationEmailSuffixWhitelist.value = normalizeRegistrationEmailSuffixWhitelist( diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue new file mode 100644 index 00000000..407b395b --- /dev/null +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -0,0 +1,361 @@ + + + + + diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts new file mode 100644 index 00000000..60a40474 --- /dev/null +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -0,0 +1,180 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import LinuxDoCallbackView from '../LinuxDoCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeLinuxDoOAuthRegistration = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) +})) + +describe('LinuxDoCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeLinuxDoOAuthRegistration.mockReset() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + completeLinuxDoOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('LinuxDo Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[0].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeLinuxDoOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts new file mode 100644 index 00000000..299c0746 --- /dev/null +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -0,0 +1,191 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import OidcCallbackView from '../OidcCallbackView.vue' + +const replace = vi.fn() +const showSuccess = vi.fn() +const showError = vi.fn() +const setToken = vi.fn() +const exchangePendingOAuthCompletion = vi.fn() +const completeOIDCOAuthRegistration = vi.fn() +const getPublicSettings = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: {} + }), + useRouter: () => ({ + replace + }) +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (!params?.providerName) { + return key + } + return `${key}:${params.providerName}` + } + }) + } +}) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken + }), + useAppStore: () => ({ + showSuccess, + showError + }) +})) + +vi.mock('@/api/auth', () => ({ + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) +})) + +describe('OidcCallbackView', () => { + beforeEach(() => { + replace.mockReset() + showSuccess.mockReset() + showError.mockReset() + setToken.mockReset() + exchangePendingOAuthCompletion.mockReset() + completeOIDCOAuthRegistration.mockReset() + getPublicSettings.mockReset() + getPublicSettings.mockResolvedValue({ + oidc_oauth_provider_name: 'ExampleID' + }) + }) + + it('does not send adoption decisions during the initial exchange', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true + }) + setToken.mockResolvedValue({}) + + mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(setToken).not.toHaveBeenCalled() + expect(replace).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + await checkboxes[0].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(2) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: false, + adoptAvatar: true + }) + expect(setToken).toHaveBeenCalledWith('access-token') + expect(replace).toHaveBeenCalledWith('/dashboard') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({ + error: 'invitation_required', + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + completeOIDCOAuthRegistration.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + token_type: 'Bearer' + }) + setToken.mockResolvedValue({}) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('OIDC Nick') + expect(exchangePendingOAuthCompletion).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletion).toHaveBeenCalledWith() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + + await checkboxes[1].setValue(false) + await wrapper.find('input[type="text"]').setValue('invite-code') + await wrapper.find('button').trigger('click') + + expect(completeOIDCOAuthRegistration).toHaveBeenCalledWith('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + }) +}) diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts new file mode 100644 index 00000000..a9e2ada2 --- /dev/null +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -0,0 +1,241 @@ +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' + +const { + postMock, + replaceMock, + setTokenMock, + showSuccessMock, + showErrorMock, + routeState, +} = vi.hoisted(() => ({ + postMock: vi.fn(), + replaceMock: vi.fn(), + setTokenMock: vi.fn(), + showSuccessMock: vi.fn(), + showErrorMock: vi.fn(), + routeState: { + query: {} as Record, + }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, + useRouter: () => ({ + replace: replaceMock, + }), +})) + +vi.mock('vue-i18n', () => ({ + createI18n: () => ({ + global: { + t: (key: string) => key, + }, + }), + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'auth.oidc.callbackTitle') { + return `Signing you in with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.callbackProcessing') { + return `Completing login with ${params?.providerName ?? ''}`.trim() + } + if (key === 'auth.oidc.invitationRequired') { + return `${params?.providerName ?? ''} invitation required`.trim() + } + if (key === 'auth.oidc.completeRegistration') { + return 'Complete registration' + } + if (key === 'auth.oidc.completing') { + return 'Completing' + } + if (key === 'auth.oidc.backToLogin') { + return 'Back to login' + } + if (key === 'auth.invitationCodePlaceholder') { + return 'Invitation code' + } + if (key === 'auth.loginSuccess') { + return 'Login success' + } + if (key === 'auth.loginFailed') { + return 'Login failed' + } + if (key === 'auth.oidc.callbackHint') { + return 'Callback hint' + } + if (key === 'auth.oidc.callbackMissingToken') { + return 'Missing login token' + } + if (key === 'auth.oidc.completeRegistrationFailed') { + return 'Complete registration failed' + } + return key + }, + }), +})) + +vi.mock('@/stores', () => ({ + useAuthStore: () => ({ + setToken: setTokenMock, + }), + useAppStore: () => ({ + showSuccess: showSuccessMock, + showError: showErrorMock, + }), +})) + +vi.mock('@/api/client', () => ({ + apiClient: { + post: postMock, + }, +})) + +describe('WechatCallbackView', () => { + beforeEach(() => { + postMock.mockReset() + replaceMock.mockReset() + setTokenMock.mockReset() + showSuccessMock.mockReset() + showErrorMock.mockReset() + routeState.query = {} + localStorage.clear() + }) + + it('does not send adoption decisions during the initial exchange', async () => { + postMock.mockResolvedValueOnce({ + data: { + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, + }, + }) + setTokenMock.mockResolvedValue({}) + + mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenCalledTimes(1) + }) + + it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { + postMock + .mockResolvedValueOnce({ + data: { + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', + }, + }) + setTokenMock.mockResolvedValue({}) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + expect(setTokenMock).not.toHaveBeenCalled() + expect(replaceMock).not.toHaveBeenCalled() + + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[1].setValue(false) + + const buttons = wrapper.findAll('button') + expect(buttons).toHaveLength(1) + await buttons[0].trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') + expect(replaceMock).toHaveBeenCalledWith('/dashboard') + expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + postMock + .mockResolvedValueOnce({ + data: { + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }, + }) + .mockResolvedValueOnce({ + data: { + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }, + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('WeChat Nick') + const checkboxes = wrapper.findAll('input[type="checkbox"]') + expect(checkboxes).toHaveLength(2) + await checkboxes[0].setValue(false) + await wrapper.get('input[type="text"]').setValue(' INVITE-CODE ') + await wrapper.get('button').trigger('click') + await flushPromises() + + expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { + invitation_code: 'INVITE-CODE', + adopt_display_name: false, + adopt_avatar: true, + }) + expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') + expect(replaceMock).toHaveBeenCalledWith('/subscriptions') + }) +}) diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index e91df5da..838f3000 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -23,20 +23,7 @@ :order-type="paymentState.orderType" @done="onPaymentDone" @success="onPaymentSuccess" - /> - - @@ -265,7 +252,7 @@ diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index b6f6022d..e82ae229 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -4,11 +4,16 @@ class="border-b border-gray-100 bg-gradient-to-r from-primary-500/10 to-primary-600/5 px-6 py-5 dark:border-dark-700 dark:from-primary-500/20 dark:to-primary-600/10" >
-
- {{ user?.email?.charAt(0).toUpperCase() || 'U' }} + + {{ avatarInitial }}

@@ -41,18 +46,163 @@ {{ user.username }}

+ +
+
+ + {{ hint.text }} +
+
+ +
diff --git a/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts new file mode 100644 index 00000000..1c9531e3 --- /dev/null +++ b/frontend/src/components/user/profile/__tests__/ProfileIdentityBindingsSection.spec.ts @@ -0,0 +1,120 @@ +import { mount } from '@vue/test-utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import ProfileIdentityBindingsSection from '@/components/user/profile/ProfileIdentityBindingsSection.vue' +import type { User } from '@/types' + +const routeState = vi.hoisted(() => ({ + fullPath: '/profile', +})) + +const locationState = vi.hoisted(() => ({ + current: { href: 'http://localhost/profile' } as { href: string }, +})) + +vi.mock('vue-router', () => ({ + useRoute: () => routeState, +})) + +vi.mock('vue-i18n', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useI18n: () => ({ + t: (key: string, params?: Record) => { + if (key === 'profile.authBindings.title') return 'Connected sign-in methods' + if (key === 'profile.authBindings.description') return 'Manage bound providers' + if (key === 'profile.authBindings.status.bound') return 'Bound' + if (key === 'profile.authBindings.status.notBound') return 'Not bound' + if (key === 'profile.authBindings.providers.email') return 'Email' + if (key === 'profile.authBindings.providers.linuxdo') return 'LinuxDo' + if (key === 'profile.authBindings.providers.wechat') return 'WeChat' + if (key === 'profile.authBindings.providers.oidc') return params?.providerName || 'OIDC' + if (key === 'profile.authBindings.bindAction') return `Bind ${params?.providerName || ''}`.trim() + return key + }, + }), + } +}) + +function createUser(overrides: Partial = {}): User { + return { + id: 7, + username: 'alice', + email: 'alice@example.com', + role: 'user', + balance: 10, + concurrency: 2, + status: 'active', + allowed_groups: null, + balance_notify_enabled: true, + balance_notify_threshold: null, + balance_notify_extra_emails: [], + created_at: '2026-04-20T00:00:00Z', + updated_at: '2026-04-20T00:00:00Z', + ...overrides, + } +} + +describe('ProfileIdentityBindingsSection', () => { + beforeEach(() => { + routeState.fullPath = '/profile' + locationState.current = { href: 'http://localhost/profile' } + Object.defineProperty(window, 'location', { + configurable: true, + value: locationState.current, + }) + Object.defineProperty(window.navigator, 'userAgent', { + configurable: true, + value: 'Mozilla/5.0', + }) + }) + + afterEach(() => { + vi.unstubAllGlobals() + }) + + it('renders provider binding states and provider-specific bind actions', () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser({ + auth_bindings: { + email: { bound: true }, + linuxdo: { bound: true }, + oidc: { bound: false }, + wechat: false, + }, + }), + linuxdoEnabled: true, + oidcEnabled: true, + oidcProviderName: 'ExampleID', + wechatEnabled: true, + }, + }) + + expect(wrapper.get('[data-testid="profile-binding-email-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-linuxdo-status"]').text()).toBe('Bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-status"]').text()).toBe('Not bound') + expect(wrapper.get('[data-testid="profile-binding-oidc-action"]').text()).toBe( + 'Bind ExampleID' + ) + expect(wrapper.get('[data-testid="profile-binding-wechat-action"]').text()).toBe('Bind WeChat') + }) + + it('starts the WeChat bind flow for the current profile page', async () => { + const wrapper = mount(ProfileIdentityBindingsSection, { + props: { + user: createUser(), + linuxdoEnabled: false, + oidcEnabled: false, + wechatEnabled: true, + }, + }) + + await wrapper.get('[data-testid="profile-binding-wechat-action"]').trigger('click') + + expect(locationState.current.href).toContain('/api/v1/auth/oauth/wechat/start?') + expect(locationState.current.href).toContain('mode=open') + expect(locationState.current.href).toContain('intent=bind_current_user') + expect(locationState.current.href).toContain('redirect=%2Fprofile') + }) +}) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 684c196f..7d058a74 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -940,6 +940,26 @@ export default { maxEmailsReached: 'Maximum number of notification emails reached', unverified: 'Unverified', verified: 'Verified', + }, + authBindings: { + title: 'Connected Sign-In Methods', + description: 'View current bindings and connect another provider to this account.', + bindAction: 'Bind {providerName}', + bindSuccess: 'Account linked successfully', + status: { + bound: 'Bound', + notBound: 'Not bound', + }, + providers: { + email: 'Email', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: 'WeChat', + }, + source: { + avatar: 'Avatar is currently synced from {providerName}', + username: 'Nickname is currently synced from {providerName}', + }, } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2a4c69a5..6dd74334 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -944,6 +944,26 @@ export default { maxEmailsReached: '已达到通知邮箱数量上限', unverified: '未验证', verified: '已验证', + }, + authBindings: { + title: '登录方式绑定', + description: '查看当前绑定状态,并将更多第三方登录方式关联到这个账号。', + bindAction: '绑定 {providerName}', + bindSuccess: '账号绑定成功', + status: { + bound: '已绑定', + notBound: '未绑定', + }, + providers: { + email: '邮箱', + linuxdo: 'LinuxDo', + oidc: '{providerName}', + wechat: '微信', + }, + source: { + avatar: '头像当前来自 {providerName}', + username: '昵称当前来自 {providerName}', + }, } }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 9c9722a9..a19d6c26 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -34,10 +34,47 @@ export interface NotifyEmailEntry { // ==================== User & Auth Types ==================== +export type UserAuthProvider = 'email' | 'linuxdo' | 'oidc' | 'wechat' + +export interface UserAuthBindingStatus { + bound?: boolean + provider?: UserAuthProvider | string + provider_key?: string | null + provider_subject?: string | null + issuer?: string | null + label?: string | null + provider_label?: string | null + metadata?: Record +} + +export interface UserProfileSourceContext { + provider?: UserAuthProvider | string + source?: string | null + label?: string | null + provider_label?: string | null +} + export interface User { id: number username: string email: string + avatar_url?: string | null + avatar_source?: string | UserProfileSourceContext | null + username_source?: string | UserProfileSourceContext | null + display_name_source?: string | UserProfileSourceContext | null + nickname_source?: string | UserProfileSourceContext | null + profile_sources?: { + avatar?: string | UserProfileSourceContext | null + username?: string | UserProfileSourceContext | null + display_name?: string | UserProfileSourceContext | null + nickname?: string | UserProfileSourceContext | null + } + auth_bindings?: Partial> + identity_bindings?: Partial> + email_bound?: boolean + linuxdo_bound?: boolean + oidc_bound?: boolean + wechat_bound?: boolean role: 'admin' | 'user' // User role for authorization balance: number // User balance for API usage concurrency: number // Allowed concurrent requests diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 0a125def..6dc8f242 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -136,6 +136,9 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeLinuxDoOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -162,6 +165,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -209,18 +213,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { - throw new Error(t('auth.linuxdo.callbackMissingToken')) +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + if (!isOAuthLoginCompletion(completion)) { + throw new Error(t('auth.linuxdo.callbackMissingToken')) } + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -236,12 +241,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -258,7 +258,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -305,7 +305,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index 55f8af6e..83304226 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -145,7 +145,10 @@ import { useAuthStore, useAppStore } from '@/stores' import { completeOIDCOAuthRegistration, exchangePendingOAuthCompletion, + getOAuthCompletionKind, getPublicSettings, + isOAuthLoginCompletion, + persistOAuthTokenContext, type OAuthAdoptionDecision, type PendingOAuthExchangeResponse } from '@/api/auth' @@ -172,6 +175,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') function parseFragmentParams(): URLSearchParams { const raw = typeof window !== 'undefined' ? window.location.hash : '' @@ -231,18 +235,19 @@ function hasSuggestedProfile(completion: { return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { - throw new Error(t('auth.oidc.callbackMissingToken')) +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + if (!isOAuthLoginCompletion(completion)) { + throw new Error(t('auth.oidc.callbackMissingToken')) } + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) @@ -258,12 +263,7 @@ async function handleSubmitInvitation() { invitationCode.value.trim(), currentAdoptionDecision() ) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -280,7 +280,7 @@ async function handleContinueLogin() { isSubmitting.value = true try { const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) - await finalizeLogin(completion, redirectTo.value) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -329,7 +329,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/WechatCallbackView.vue b/frontend/src/views/auth/WechatCallbackView.vue index 407b395b..ac0ede4c 100644 --- a/frontend/src/views/auth/WechatCallbackView.vue +++ b/frontend/src/views/auth/WechatCallbackView.vue @@ -140,27 +140,16 @@ import { useRoute, useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { AuthLayout } from '@/components/layout' import Icon from '@/components/icons/Icon.vue' -import { apiClient } from '@/api/client' import { useAuthStore, useAppStore } from '@/stores' - -interface OAuthTokenResponse { - access_token: string - refresh_token: string - expires_in: number - token_type: string -} - -interface PendingOAuthExchangeResponse { - access_token?: string - refresh_token?: string - expires_in?: number - token_type?: string - redirect?: string - error?: string - adoption_required?: boolean - suggested_display_name?: string - suggested_avatar_url?: string -} +import { + completeWeChatOAuthRegistration, + exchangePendingOAuthCompletion, + getOAuthCompletionKind, + isOAuthLoginCompletion, + persistOAuthTokenContext, + type OAuthAdoptionDecision, + type PendingOAuthExchangeResponse +} from '@/api/auth' const route = useRoute() const router = useRouter() @@ -182,6 +171,7 @@ const suggestedAvatarUrl = ref('') const adoptDisplayName = ref(true) const adoptAvatar = ref(true) const needsAdoptionConfirmation = ref(false) +const bindSuccessMessage = t('profile.authBindings.bindSuccess') const providerName = 'WeChat' @@ -200,10 +190,10 @@ function sanitizeRedirectPath(path: string | null | undefined): string { return path } -function currentAdoptionDecision(): Record { +function currentAdoptionDecision(): OAuthAdoptionDecision { return { - adopt_display_name: adoptDisplayName.value, - adopt_avatar: adoptAvatar.value, + adoptDisplayName: adoptDisplayName.value, + adoptAvatar: adoptAvatar.value } } @@ -224,49 +214,35 @@ function hasSuggestedProfile(completion: PendingOAuthExchangeResponse): boolean return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) } -async function exchangePendingOAuthCompletion(): Promise { - const { data } = await apiClient.post('/auth/oauth/pending/exchange', {}) - return data -} - -async function finalizeLogin(completion: PendingOAuthExchangeResponse, redirect: string) { - if (!completion.access_token) { - throw new Error(t('auth.oidc.callbackMissingToken')) +async function finalizeCompletion(completion: PendingOAuthExchangeResponse, redirect: string) { + if (getOAuthCompletionKind(completion) === 'bind') { + const bindRedirect = sanitizeRedirectPath(completion.redirect || '/profile') + appStore.showSuccess(bindSuccessMessage) + await router.replace(bindRedirect) + return } - if (completion.refresh_token) { - localStorage.setItem('refresh_token', completion.refresh_token) - } - if (completion.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + completion.expires_in * 1000)) + if (!isOAuthLoginCompletion(completion)) { + throw new Error(t('auth.oidc.callbackMissingToken')) } + persistOAuthTokenContext(completion) await authStore.setToken(completion.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) } -async function completeWeChatOAuthRegistration(invitation: string): Promise { - const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { - invitation_code: invitation, - ...currentAdoptionDecision(), - }) - return data -} - async function handleSubmitInvitation() { invitationError.value = '' if (!invitationCode.value.trim()) return isSubmitting.value = true try { - const tokenData = await completeWeChatOAuthRegistration(invitationCode.value.trim()) - if (tokenData.refresh_token) { - localStorage.setItem('refresh_token', tokenData.refresh_token) - } - if (tokenData.expires_in) { - localStorage.setItem('token_expires_at', String(Date.now() + tokenData.expires_in * 1000)) - } + const tokenData = await completeWeChatOAuthRegistration( + invitationCode.value.trim(), + currentAdoptionDecision() + ) + persistOAuthTokenContext(tokenData) await authStore.setToken(tokenData.access_token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirectTo.value) @@ -282,11 +258,8 @@ async function handleSubmitInvitation() { async function handleContinueLogin() { isSubmitting.value = true try { - const { data } = await apiClient.post( - '/auth/oauth/pending/exchange', - currentAdoptionDecision() - ) - await finalizeLogin(data, redirectTo.value) + const completion = await exchangePendingOAuthCompletion(currentAdoptionDecision()) + await finalizeCompletion(completion, redirectTo.value) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = @@ -333,7 +306,7 @@ onMounted(async () => { return } - await finalizeLogin(completion, redirect) + await finalizeCompletion(completion, redirect) } catch (e: unknown) { const err = e as { message?: string; response?: { data?: { detail?: string; message?: string } } } errorMessage.value = diff --git a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts index 60a40474..7ffdcd19 100644 --- a/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/LinuxDoCallbackView.spec.ts @@ -39,10 +39,14 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeLinuxDoOAuthRegistration: (...args: any[]) => completeLinuxDoOAuthRegistration(...args) + } +}) describe('LinuxDoCallbackView', () => { beforeEach(() => { @@ -132,6 +136,64 @@ describe('LinuxDoCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('treats a completion without token as bind success and returns to profile', async () => { + exchangePendingOAuthCompletion.mockResolvedValue({}) + + mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'LinuxDo Nick', + suggested_avatar_url: 'https://cdn.example/linuxdo.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile/security' + }) + + const wrapper = mount(LinuxDoCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile/security') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts index 299c0746..f8de79f2 100644 --- a/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/OidcCallbackView.spec.ts @@ -45,11 +45,15 @@ vi.mock('@/stores', () => ({ }) })) -vi.mock('@/api/auth', () => ({ - exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), - completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), - getPublicSettings: (...args: any[]) => getPublicSettings(...args) -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletion(...args), + completeOIDCOAuthRegistration: (...args: any[]) => completeOIDCOAuthRegistration(...args), + getPublicSettings: (...args: any[]) => getPublicSettings(...args) + } +}) describe('OidcCallbackView', () => { beforeEach(() => { @@ -143,6 +147,43 @@ describe('OidcCallbackView', () => { expect(replace).toHaveBeenCalledWith('/dashboard') }) + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletion + .mockResolvedValueOnce({ + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'OIDC Nick', + suggested_avatar_url: 'https://cdn.example/oidc.png' + }) + .mockResolvedValueOnce({ + redirect: '/profile' + }) + + const wrapper = mount(OidcCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false + } + } + }) + + await flushPromises() + + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletion).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true + }) + expect(setToken).not.toHaveBeenCalled() + expect(showSuccess).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replace).toHaveBeenCalledWith('/profile') + }) + it('renders adoption choices for invitation flow and submits the selected values', async () => { exchangePendingOAuthCompletion.mockResolvedValue({ error: 'invitation_required', diff --git a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts index a9e2ada2..896bf15d 100644 --- a/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts +++ b/frontend/src/views/auth/__tests__/WechatCallbackView.spec.ts @@ -3,14 +3,16 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import WechatCallbackView from '@/views/auth/WechatCallbackView.vue' const { - postMock, + exchangePendingOAuthCompletionMock, + completeWeChatOAuthRegistrationMock, replaceMock, setTokenMock, showSuccessMock, showErrorMock, routeState, } = vi.hoisted(() => ({ - postMock: vi.fn(), + exchangePendingOAuthCompletionMock: vi.fn(), + completeWeChatOAuthRegistrationMock: vi.fn(), replaceMock: vi.fn(), setTokenMock: vi.fn(), showSuccessMock: vi.fn(), @@ -86,15 +88,19 @@ vi.mock('@/stores', () => ({ }), })) -vi.mock('@/api/client', () => ({ - apiClient: { - post: postMock, - }, -})) +vi.mock('@/api/auth', async () => { + const actual = await vi.importActual('@/api/auth') + return { + ...actual, + exchangePendingOAuthCompletion: (...args: any[]) => exchangePendingOAuthCompletionMock(...args), + completeWeChatOAuthRegistration: (...args: any[]) => completeWeChatOAuthRegistrationMock(...args), + } +}) describe('WechatCallbackView', () => { beforeEach(() => { - postMock.mockReset() + exchangePendingOAuthCompletionMock.mockReset() + completeWeChatOAuthRegistrationMock.mockReset() replaceMock.mockReset() setTokenMock.mockReset() showSuccessMock.mockReset() @@ -104,14 +110,12 @@ describe('WechatCallbackView', () => { }) it('does not send adoption decisions during the initial exchange', async () => { - postMock.mockResolvedValueOnce({ - data: { - access_token: 'access-token', - refresh_token: 'refresh-token', - expires_in: 3600, - redirect: '/dashboard', - adoption_required: true, - }, + exchangePendingOAuthCompletionMock.mockResolvedValue({ + access_token: 'access-token', + refresh_token: 'refresh-token', + expires_in: 3600, + redirect: '/dashboard', + adoption_required: true, }) setTokenMock.mockResolvedValue({}) @@ -128,28 +132,24 @@ describe('WechatCallbackView', () => { await flushPromises() - expect(postMock).toHaveBeenCalledWith('/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenCalledTimes(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledWith() + expect(exchangePendingOAuthCompletionMock).toHaveBeenCalledTimes(1) }) it('waits for explicit adoption confirmation before finishing a non-invitation login', async () => { - postMock + exchangePendingOAuthCompletionMock .mockResolvedValueOnce({ - data: { - redirect: '/dashboard', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', }) .mockResolvedValueOnce({ - data: { - access_token: 'wechat-access-token', - refresh_token: 'wechat-refresh-token', - expires_in: 3600, - token_type: 'Bearer', - redirect: '/dashboard', - }, + access_token: 'wechat-access-token', + refresh_token: 'wechat-refresh-token', + expires_in: 3600, + token_type: 'Bearer', + redirect: '/dashboard', }) setTokenMock.mockResolvedValue({}) @@ -179,34 +179,26 @@ describe('WechatCallbackView', () => { await buttons[0].trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(1, '/auth/oauth/pending/exchange', {}) - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/pending/exchange', { - adopt_display_name: true, - adopt_avatar: false, + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(1) + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: false, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-access-token') expect(replaceMock).toHaveBeenCalledWith('/dashboard') expect(localStorage.getItem('refresh_token')).toBe('wechat-refresh-token') }) - it('renders adoption choices for invitation flow and submits the selected values', async () => { - postMock + it('supports bind completion after adoption confirmation', async () => { + exchangePendingOAuthCompletionMock .mockResolvedValueOnce({ - data: { - error: 'invitation_required', - redirect: '/subscriptions', - adoption_required: true, - suggested_display_name: 'WeChat Nick', - suggested_avatar_url: 'https://cdn.example/wechat.png', - }, + redirect: '/dashboard', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', }) .mockResolvedValueOnce({ - data: { - access_token: 'wechat-invite-token', - refresh_token: 'wechat-invite-refresh', - expires_in: 600, - token_type: 'Bearer', - }, + redirect: '/profile/connections', }) const wrapper = mount(WechatCallbackView, { @@ -222,6 +214,46 @@ describe('WechatCallbackView', () => { await flushPromises() + await wrapper.findAll('button')[0].trigger('click') + await flushPromises() + + expect(exchangePendingOAuthCompletionMock).toHaveBeenNthCalledWith(2, { + adoptDisplayName: true, + adoptAvatar: true, + }) + expect(setTokenMock).not.toHaveBeenCalled() + expect(showSuccessMock).toHaveBeenCalledWith('profile.authBindings.bindSuccess') + expect(replaceMock).toHaveBeenCalledWith('/profile/connections') + }) + + it('renders adoption choices for invitation flow and submits the selected values', async () => { + exchangePendingOAuthCompletionMock.mockResolvedValue({ + error: 'invitation_required', + redirect: '/subscriptions', + adoption_required: true, + suggested_display_name: 'WeChat Nick', + suggested_avatar_url: 'https://cdn.example/wechat.png', + }) + completeWeChatOAuthRegistrationMock.mockResolvedValue({ + access_token: 'wechat-invite-token', + refresh_token: 'wechat-invite-refresh', + expires_in: 600, + token_type: 'Bearer', + }) + + const wrapper = mount(WechatCallbackView, { + global: { + stubs: { + AuthLayout: { template: '
' }, + Icon: true, + RouterLink: { template: '' }, + transition: false, + }, + }, + }) + + await flushPromises() + expect(wrapper.text()).toContain('WeChat Nick') const checkboxes = wrapper.findAll('input[type="checkbox"]') expect(checkboxes).toHaveLength(2) @@ -230,10 +262,9 @@ describe('WechatCallbackView', () => { await wrapper.get('button').trigger('click') await flushPromises() - expect(postMock).toHaveBeenNthCalledWith(2, '/auth/oauth/wechat/complete-registration', { - invitation_code: 'INVITE-CODE', - adopt_display_name: false, - adopt_avatar: true, + expect(completeWeChatOAuthRegistrationMock).toHaveBeenCalledWith('INVITE-CODE', { + adoptDisplayName: false, + adoptAvatar: true, }) expect(setTokenMock).toHaveBeenCalledWith('wechat-invite-token') expect(replaceMock).toHaveBeenCalledWith('/subscriptions') diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue index e7418ebb..f7418be9 100644 --- a/frontend/src/views/user/ProfileView.vue +++ b/frontend/src/views/user/ProfileView.vue @@ -2,18 +2,53 @@
- - - + + +
- -
+ + + +
-
-

{{ t('common.contactSupport') }}

{{ contactInfo }}

+
+ +
+
+

+ {{ t('common.contactSupport') }} +

+

{{ contactInfo }}

+
+ + +
@@ -29,26 +65,78 @@ \ No newline at end of file +const formatCurrency = (value: number) => `$${value.toFixed(2)}` + -- GitLab From 4e0e69154649def4f4149054e09ff03fc8a3e50a Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:39:53 +0800 Subject: [PATCH 037/265] feat: apply auth source signup defaults --- .../service/admin_service_apikey_test.go | 3 + .../service/admin_service_delete_test.go | 51 ++++-- backend/internal/service/auth_service.go | 117 +++++++++---- .../service/auth_service_register_test.go | 159 +++++++++++++++++- 4 files changed, 283 insertions(+), 47 deletions(-) diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index b802a9c2..487fb5f1 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -79,6 +79,9 @@ func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *s } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 323286b0..ac1d8ee7 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -13,15 +13,18 @@ import ( ) type userRepoStub struct { - user *User - getErr error - createErr error - deleteErr error - exists bool - existsErr error - nextID int64 - created []*User - deletedIDs []int64 + user *User + getErr error + createErr error + deleteErr error + exists bool + existsErr error + nextID int64 + created []*User + updated []*User + deletedIDs []int64 + usersByEmail map[string]*User + getByEmailErr error } func (s *userRepoStub) Create(ctx context.Context, user *User) error { @@ -32,6 +35,11 @@ func (s *userRepoStub) Create(ctx context.Context, user *User) error { user.ID = s.nextID } s.created = append(s.created, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user return nil } @@ -46,7 +54,18 @@ func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) { } func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) { - panic("unexpected GetByEmail call") + if s.getByEmailErr != nil { + return nil, s.getByEmailErr + } + if s.usersByEmail != nil { + if user, ok := s.usersByEmail[email]; ok { + return user, nil + } + } + if s.user != nil && s.user.Email == email { + return s.user, nil + } + return nil, ErrUserNotFound } func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { @@ -54,7 +73,13 @@ func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) { } func (s *userRepoStub) Update(ctx context.Context, user *User) error { - panic("unexpected Update call") + s.updated = append(s.updated, user) + if s.usersByEmail == nil { + s.usersByEmail = make(map[string]*User) + } + s.usersByEmail[user.Email] = user + s.user = user + return nil } func (s *userRepoStub) Delete(ctx context.Context, id int64) error { @@ -113,6 +138,10 @@ func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64 panic("unexpected AddGroupToAllowedGroups call") } +func (s *userRepoStub) ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) { + panic("unexpected ListUserAuthIdentities call") +} + func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { panic("unexpected UpdateTotpSecret call") } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 962009ce..40753139 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -78,6 +78,12 @@ type DefaultSubscriptionAssigner interface { AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) } +type signupGrantPlan struct { + Balance float64 + Concurrency int + Subscriptions []DefaultSubscriptionSetting +} + // NewAuthService 创建认证服务实例 func NewAuthService( entClient *dbent.Client, @@ -187,21 +193,15 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, fmt.Errorf("hash password: %w", err) } - // 获取默认配置 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + grantPlan := s.resolveSignupGrantPlan(ctx, "email") // 创建用户 user := &User{ Email: email, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -214,7 +214,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrServiceUnavailable } s.postAuthUserBootstrap(ctx, user, "email", true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -479,21 +479,16 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return "", nil, fmt.Errorf("hash password: %w", err) } - // 新用户默认值。 - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -511,8 +506,8 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -596,20 +591,16 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - defaultBalance := s.cfg.Default.UserBalance - defaultConcurrency := s.cfg.Default.UserConcurrency - if s.settingService != nil { - defaultBalance = s.settingService.GetDefaultBalance(ctx) - defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) - } + signupSource := inferLegacySignupSource(email) + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) newUser := &User{ Email: email, Username: username, PasswordHash: hashedPassword, Role: RoleUser, - Balance: defaultBalance, - Concurrency: defaultConcurrency, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, Status: StatusActive, } @@ -642,8 +633,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, ErrServiceUnavailable } user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") } } else { if err := s.userRepo.Create(ctx, newUser); err != nil { @@ -659,8 +650,8 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true) - s.assignDefaultSubscriptions(ctx, user.ID) + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { return nil, nil, ErrInvitationCodeInvalid @@ -694,22 +685,78 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { + if s.settingService == nil { + return + } + s.assignSubscriptions(ctx, userID, s.settingService.GetDefaultSubscriptions(ctx), "auto assigned by default user subscriptions setting") +} + +func (s *AuthService) assignSubscriptions(ctx context.Context, userID int64, items []DefaultSubscriptionSetting, notes string) { if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { return } - items := s.settingService.GetDefaultSubscriptions(ctx) for _, item := range items { if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ UserID: userID, GroupID: item.GroupID, ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions setting", + Notes: notes, }); err != nil { logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) } } } +func (s *AuthService) resolveSignupGrantPlan(ctx context.Context, signupSource string) signupGrantPlan { + plan := signupGrantPlan{} + if s != nil && s.cfg != nil { + plan.Balance = s.cfg.Default.UserBalance + plan.Concurrency = s.cfg.Default.UserConcurrency + } + if s == nil || s.settingService == nil { + return plan + } + + plan.Balance = s.settingService.GetDefaultBalance(ctx) + plan.Concurrency = s.settingService.GetDefaultConcurrency(ctx) + plan.Subscriptions = s.settingService.GetDefaultSubscriptions(ctx) + + defaults, err := s.settingService.GetAuthSourceDefaultSettings(ctx) + if err != nil { + logger.LegacyPrintf("service.auth", "[Auth] Failed to load auth source signup defaults for %s: %v", signupSource, err) + return plan + } + + providerDefaults, ok := authSourceSignupSettings(defaults, signupSource) + if !ok || !providerDefaults.GrantOnSignup { + return plan + } + + plan.Balance = providerDefaults.Balance + plan.Concurrency = providerDefaults.Concurrency + plan.Subscriptions = providerDefaults.Subscriptions + return plan +} + +func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource string) (ProviderDefaultGrantSettings, bool) { + if defaults == nil { + return ProviderDefaultGrantSettings{}, false + } + + switch strings.ToLower(strings.TrimSpace(signupSource)) { + case "email": + return defaults.Email, true + case "linuxdo": + return defaults.LinuxDo, true + case "oidc": + return defaults.OIDC, true + case "wechat": + return defaults.WeChat, true + default: + return ProviderDefaultGrantSettings{}, false + } +} + func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) { if user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 103bafe7..901b3db3 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -37,7 +37,16 @@ func (s *settingRepoStub) Set(ctx context.Context, key, value string) error { } func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") + if s.err != nil { + return nil, s.err + } + result := make(map[string]string, len(keys)) + for _, key := range keys { + if v, ok := s.values[key]; ok { + result[key] = v + } + } + return result, nil } func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { @@ -62,6 +71,8 @@ type defaultSubscriptionAssignerStub struct { err error } +type refreshTokenCacheStub struct{} + func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { if input != nil { s.calls = append(s.calls, *input) @@ -72,6 +83,46 @@ func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.C return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil } +func (s *refreshTokenCacheStub) StoreRefreshToken(context.Context, string, *RefreshTokenData, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetRefreshToken(context.Context, string) (*RefreshTokenData, error) { + return nil, ErrRefreshTokenNotFound +} + +func (s *refreshTokenCacheStub) DeleteRefreshToken(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteUserRefreshTokens(context.Context, int64) error { + return nil +} + +func (s *refreshTokenCacheStub) DeleteTokenFamily(context.Context, string) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToUserTokenSet(context.Context, int64, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) AddToFamilyTokenSet(context.Context, string, string, time.Duration) error { + return nil +} + +func (s *refreshTokenCacheStub) GetUserTokenHashes(context.Context, int64) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) GetFamilyTokenHashes(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *refreshTokenCacheStub) IsTokenInFamily(context.Context, string, string) (bool, error) { + return false, nil +} + func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -484,3 +535,109 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { require.Equal(t, int64(12), assigner.calls[1].GroupID) require.Equal(t, 7, assigner.calls[1].ValidityDays) } + +func TestAuthService_Register_UsesEmailAuthSourceDefaultsWhenGrantEnabled(t *testing.T) { + repo := &userRepoStub{nextID: 52} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":91,"validity_days":3}]`, + SettingKeyAuthSourceDefaultEmailBalance: "12.5", + SettingKeyAuthSourceDefaultEmailConcurrency: "7", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-defaults@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 12.5, user.Balance) + require.Equal(t, 7, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(11), assigner.calls[0].GroupID) + require.Equal(t, 30, assigner.calls[0].ValidityDays) +} + +func TestAuthService_Register_GrantOnSignupFalseFallsBackToGlobalDefaults(t *testing.T) { + repo := &userRepoStub{nextID: 53} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":31,"validity_days":5}]`, + SettingKeyAuthSourceDefaultEmailBalance: "99", + SettingKeyAuthSourceDefaultEmailConcurrency: "88", + SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":32,"validity_days":9}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + }, nil) + service.defaultSubAssigner = assigner + + _, user, err := service.Register(context.Background(), "email-global@test.com", "password") + require.NoError(t, err) + require.NotNil(t, user) + require.Equal(t, 3.5, user.Balance) + require.Equal(t, 2, user.Concurrency) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(31), assigner.calls[0].GroupID) + require.Equal(t, 5, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefaultsOnSignup(t *testing.T) { + repo := &userRepoStub{nextID: 61} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":81,"validity_days":1}]`, + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.NotNil(t, user) + require.Equal(t, int64(61), user.ID) + require.Equal(t, 21.75, user.Balance) + require.Equal(t, 9, user.Concurrency) + require.Len(t, repo.created, 1) + require.Len(t, assigner.calls, 1) + require.Equal(t, int64(22), assigner.calls[0].GroupID) + require.Equal(t, 14, assigner.calls[0].ValidityDays) +} + +func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantAgain(t *testing.T) { + existing := &User{ + ID: 88, + Email: "linuxdo-123@linuxdo-connect.invalid", + Username: "existing-linuxdo", + Role: RoleUser, + Status: StatusActive, + Balance: 4, + Concurrency: 1, + TokenVersion: 2, + } + repo := &userRepoStub{user: existing} + assigner := &defaultSubscriptionAssignerStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultLinuxDoBalance: "21.75", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "9", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: `[{"group_id":22,"validity_days":14}]`, + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true", + }, nil) + service.defaultSubAssigner = assigner + service.refreshTokenCache = &refreshTokenCacheStub{} + + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "") + require.NoError(t, err) + require.NotNil(t, tokenPair) + require.Equal(t, existing.ID, user.ID) + require.Equal(t, 4.0, user.Balance) + require.Equal(t, 1, user.Concurrency) + require.Empty(t, repo.created) + require.Empty(t, assigner.calls) +} -- GitLab From 0353c3870f938f5d57d4060232dfe9e3bc4a01ff Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:40:34 +0800 Subject: [PATCH 038/265] test: update user service stubs for identity summaries --- .../service/billing_cache_service_singleflight_test.go | 4 ++++ backend/internal/service/user_service_test.go | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 4a8b8f03..2ebc2f04 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -86,6 +86,10 @@ func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, return &User{ID: id, Balance: s.balance}, nil } +func (s *balanceLoadUserRepoStub) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} + func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { cache := &billingCacheMissStub{} userRepo := &balanceLoadUserRepoStub{ diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 7d63bb36..89f0362e 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -100,9 +100,12 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } -func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { + return nil, nil +} +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } +func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } -- GitLab From d47580a144f083965b04e4612e4fd9ceb4779e35 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 18:42:28 +0800 Subject: [PATCH 039/265] test: pin email signup defaults in register tests --- backend/internal/service/auth_service_register_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 901b3db3..e0dce982 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -373,7 +373,8 @@ func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", + SettingKeyRegistrationEnabled: "true", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") @@ -520,8 +521,9 @@ func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { repo := &userRepoStub{nextID: 42} assigner := &defaultSubscriptionAssignerStub{} service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyRegistrationEnabled: "true", + SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", }, nil) service.defaultSubAssigner = assigner -- GitLab From 6a75bd77e3cbf11a12a624a474a979f684be5fb6 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 19:30:09 +0800 Subject: [PATCH 040/265] feat: add pending oauth email onboarding flow --- .../internal/handler/auth_linuxdo_oauth.go | 125 +++-- .../handler/auth_oauth_pending_flow.go | 426 +++++++++++++++++- .../handler/auth_oauth_pending_flow_test.go | 371 ++++++++++++++- backend/internal/handler/auth_oidc_oauth.go | 137 +++--- backend/internal/handler/auth_wechat_oauth.go | 33 ++ backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/setting_handler.go | 1 + .../handler/setting_handler_public_test.go | 83 ++++ backend/internal/server/routes/auth.go | 48 ++ .../internal/service/auth_oauth_email_flow.go | 151 +++++++ backend/internal/service/setting_service.go | 2 + .../service/setting_service_public_test.go | 13 + backend/internal/service/settings_view.go | 1 + 13 files changed, 1273 insertions(+), 119 deletions(-) create mode 100644 backend/internal/handler/setting_handler_public_test.go create mode 100644 backend/internal/service/auth_oauth_email_flow.go diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 835e5fd8..c4ecb8fa 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -243,6 +243,18 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { if subject != "" { email = linuxDoSyntheticEmail(subject) } + identityKey := service.PendingAuthIdentityKey{ + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "suggested_display_name": displayName, + "suggested_avatar_url": avatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) if err != nil { @@ -250,23 +262,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityKey, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -278,27 +280,60 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityKey) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -316,23 +351,13 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "linuxdo", - ProviderKey: "linuxdo", - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "suggested_display_name": displayName, - "suggested_avatar_url": avatarURL, - }, + Intent: "login", + Identity: identityKey, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 2d6c3714..99b9b406 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -46,6 +46,36 @@ type oauthAdoptionDecisionRequest struct { AdoptAvatar *bool `json:"adopt_avatar,omitempty"` } +type bindPendingOAuthLoginRequest struct { + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +type createPendingOAuthAccountRequest struct { + Email string `json:"email" binding:"required,email"` + VerifyCode string `json:"verify_code,omitempty"` + Password string `json:"password" binding:"required,min=6"` + InvitationCode string `json:"invitation_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +func (r bindPendingOAuthLoginRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + +func (r createPendingOAuthAccountRequest) adoptionDecision() oauthAdoptionDecisionRequest { + return oauthAdoptionDecisionRequest{ + AdoptDisplayName: r.AdoptDisplayName, + AdoptAvatar: r.AdoptAvatar, + } +} + func (h *AuthHandler) pendingIdentityService() (*service.AuthPendingIdentityService, error) { if h == nil || h.authService == nil || h.authService.EntClient() == nil { return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") @@ -170,6 +200,36 @@ func readCompletionResponse(session map[string]any) (map[string]any, bool) { return result, true } +func clonePendingMap(values map[string]any) map[string]any { + if len(values) == 0 { + return map[string]any{} + } + cloned := make(map[string]any, len(values)) + for key, value := range values { + cloned[key] = value + } + return cloned +} + +func mergePendingCompletionResponse(session *dbent.PendingAuthSession, overrides map[string]any) map[string]any { + payload, _ := readCompletionResponse(session.LocalFlowState) + merged := clonePendingMap(payload) + if strings.TrimSpace(session.RedirectTo) != "" { + if _, exists := merged["redirect"]; !exists { + merged["redirect"] = session.RedirectTo + } + } + for key, value := range overrides { + if value == nil { + delete(merged, key) + continue + } + merged[key] = value + } + applySuggestedProfileToCompletionResponse(merged, session.UpstreamIdentityClaims) + return merged +} + func pendingSessionStringValue(values map[string]any, key string) string { if len(values) == 0 { return "" @@ -264,6 +324,89 @@ func (h *AuthHandler) entClient() *dbent.Client { return h.authService.EntClient() } +func (h *AuthHandler) isForceEmailOnThirdPartySignup(ctx context.Context) bool { + if h == nil || h.settingSvc == nil { + return false + } + defaults, err := h.settingSvc.GetAuthSourceDefaultSettings(ctx) + if err != nil || defaults == nil { + return false + } + return defaults.ForceEmailOnThirdPartySignup +} + +func (h *AuthHandler) findOAuthIdentityUser(ctx context.Context, identity service.PendingAuthIdentityKey) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(strings.TrimSpace(identity.ProviderSubject)), + ). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + + userEntity, err := client.User.Get(ctx, record.UserID) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_USER_LOOKUP_FAILED", "failed to load auth identity user").WithCause(err) + } + return userEntity, nil +} + +func (h *AuthHandler) createOAuthEmailRequiredPendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, +) error { + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "email_required", + "force_email_on_signup": true, + "email_binding_required": true, + "existing_account_bindable": true, + }, + }) +} + +func (h *AuthHandler) BindLinuxDoOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "linuxdo") } +func (h *AuthHandler) BindOIDCOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "oidc") } +func (h *AuthHandler) BindWeChatOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "wechat") } +func (h *AuthHandler) BindPendingOAuthLogin(c *gin.Context) { h.bindPendingOAuthLogin(c, "") } + +func (h *AuthHandler) CreateLinuxDoOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "linuxdo") +} + +func (h *AuthHandler) CreateOIDCOAuthAccount(c *gin.Context) { h.createPendingOAuthAccount(c, "oidc") } + +func (h *AuthHandler) CreateWeChatOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "wechat") +} + +func (h *AuthHandler) CreatePendingOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "") +} + func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( c *gin.Context, sessionID int64, @@ -313,6 +456,60 @@ func (h *AuthHandler) upsertPendingOAuthAdoptionDecision( return decision, nil } +func (h *AuthHandler) ensurePendingOAuthAdoptionDecision( + c *gin.Context, + sessionID int64, + req oauthAdoptionDecisionRequest, +) (*dbent.IdentityAdoptionDecision, error) { + decision, err := h.upsertPendingOAuthAdoptionDecision(c, sessionID, req) + if err != nil { + return nil, err + } + if decision != nil { + return decision, nil + } + + svc, err := h.pendingIdentityService() + if err != nil { + return nil, err + } + decision, err = svc.UpsertAdoptionDecision(c.Request.Context(), service.PendingIdentityAdoptionDecisionInput{ + PendingAuthSessionID: sessionID, + }) + if err != nil { + return nil, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_SAVE_FAILED", "failed to save oauth profile adoption decision").WithCause(err) + } + return decision, nil +} + +func updatePendingOAuthSessionProgress( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + intent string, + resolvedEmail string, + targetUserID *int64, + completionResponse map[string]any, +) (*dbent.PendingAuthSession, error) { + if client == nil || session == nil { + return nil, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") + } + + localFlowState := clonePendingMap(session.LocalFlowState) + localFlowState[oauthCompletionResponseKey] = clonePendingMap(completionResponse) + + update := client.PendingAuthSession.UpdateOneID(session.ID). + SetIntent(strings.TrimSpace(intent)). + SetResolvedEmail(strings.TrimSpace(resolvedEmail)). + SetLocalFlowState(localFlowState) + if targetUserID != nil && *targetUserID > 0 { + update = update.SetTargetUserID(*targetUserID) + } else { + update = update.ClearTargetUserID() + } + return update.Save(ctx) +} + func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession) (int64, error) { if session == nil { return 0, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth session is invalid") @@ -401,17 +598,18 @@ func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision return decision.AdoptDisplayName || decision.AdoptAvatar } -func applyPendingOAuthAdoption( +func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, + forceBind bool, ) error { - if client == nil || session == nil || decision == nil { + if client == nil || session == nil { return nil } - if !shouldBindPendingOAuthIdentity(session, decision) { + if !forceBind && !shouldBindPendingOAuthIdentity(session, decision) { return nil } @@ -427,11 +625,11 @@ func applyPendingOAuthAdoption( } adoptedDisplayName := "" - if decision.AdoptDisplayName { + if decision != nil && decision.AdoptDisplayName { adoptedDisplayName = normalizeAdoptedOAuthDisplayName(pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_display_name")) } adoptedAvatarURL := "" - if decision.AdoptAvatar { + if decision != nil && decision.AdoptAvatar { adoptedAvatarURL = pendingSessionStringValue(session.UpstreamIdentityClaims, "suggested_avatar_url") } @@ -441,7 +639,7 @@ func applyPendingOAuthAdoption( } defer func() { _ = tx.Rollback() }() - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { if err := tx.Client().User.UpdateOneID(targetUserID). SetUsername(adoptedDisplayName). Exec(ctx); err != nil { @@ -458,10 +656,10 @@ func applyPendingOAuthAdoption( for key, value := range session.UpstreamIdentityClaims { metadata[key] = value } - if decision.AdoptDisplayName && adoptedDisplayName != "" { + if decision != nil && decision.AdoptDisplayName && adoptedDisplayName != "" { metadata["display_name"] = adoptedDisplayName } - if decision.AdoptAvatar && adoptedAvatarURL != "" { + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" { metadata["avatar_url"] = adoptedAvatarURL } @@ -473,7 +671,7 @@ func applyPendingOAuthAdoption( return err } - if decision.IdentityID == nil || *decision.IdentityID != identity.ID { + if decision != nil && (decision.IdentityID == nil || *decision.IdentityID != identity.ID) { if _, err := tx.Client().IdentityAdoptionDecision.UpdateOneID(decision.ID). SetIdentityID(identity.ID). Save(ctx); err != nil { @@ -484,6 +682,16 @@ func applyPendingOAuthAdoption( return tx.Commit() } +func applyPendingOAuthAdoption( + ctx context.Context, + client *dbent.Client, + session *dbent.PendingAuthSession, + decision *dbent.IdentityAdoptionDecision, + overrideUserID *int64, +) error { + return applyPendingOAuthBinding(ctx, client, session, decision, overrideUserID, false) +} + func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream map[string]any) { if len(payload) == 0 || len(upstream) == 0 { return @@ -507,6 +715,206 @@ func applySuggestedProfileToCompletionResponse(payload map[string]any, upstream } } +func readPendingOAuthBrowserSession(c *gin.Context, h *AuthHandler) (*service.AuthPendingIdentityService, *dbent.PendingAuthSession, func(), error) { + secureCookie := isRequestHTTPS(c) + clearCookies := func() { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + } + + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil || strings.TrimSpace(sessionToken) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthSessionNotFound + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil || strings.TrimSpace(browserSessionKey) == "" { + clearCookies() + return nil, nil, clearCookies, service.ErrPendingAuthBrowserMismatch + } + + svc, err := h.pendingIdentityService() + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + session, err := svc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearCookies() + return nil, nil, clearCookies, err + } + + return svc, session, clearCookies, nil +} + +func buildPendingOAuthSessionStatusPayload(session *dbent.PendingAuthSession) gin.H { + payload := gin.H{ + "auth_result": "pending_session", + "provider": strings.TrimSpace(session.ProviderType), + "intent": strings.TrimSpace(session.Intent), + } + for key, value := range mergePendingCompletionResponse(session, nil) { + payload[key] = value + } + if email := strings.TrimSpace(session.ResolvedEmail); email != "" { + payload["email"] = email + } + return payload +} + +func writeOAuthTokenPairResponse(c *gin.Context, tokenPair *service.TokenPair) { + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { + var req bindPendingOAuthLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + user, err := h.authService.ValidatePasswordCredentials(c.Request.Context(), strings.TrimSpace(req.Email), req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + if session.TargetUserID != nil && *session.TargetUserID > 0 && user.ID != *session.TargetUserID { + response.ErrorFrom(c, infraerrors.Conflict("PENDING_AUTH_TARGET_USER_MISMATCH", "pending oauth session must be completed by the targeted user")) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + response.InternalError(c, "Failed to generate token pair") + return + } + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + +func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) { + var req createPendingOAuthAccountRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + pendingSvc, session, clearCookies, err := readPendingOAuthBrowserSession(c, h) + if err != nil { + response.ErrorFrom(c, err) + return + } + if strings.TrimSpace(provider) != "" && !strings.EqualFold(strings.TrimSpace(session.ProviderType), provider) { + response.BadRequest(c, "Pending oauth session provider mismatch") + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + + email := strings.TrimSpace(strings.ToLower(req.Email)) + existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context()) + if err != nil && !dbent.IsNotFound(err) { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) + return + } + if existingUser != nil { + completionResponse := mergePendingCompletionResponse(session, map[string]any{ + "step": "bind_login_required", + "email": email, + }) + session, err = updatePendingOAuthSessionProgress( + c.Request.Context(), + client, + session, + "adopt_existing_user_by_email", + email, + &existingUser.ID, + completionResponse, + ) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_SESSION_UPDATE_FAILED", "failed to update pending oauth session").WithCause(err)) + return + } + + if _, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()); err != nil { + response.ErrorFrom(c, err) + return + } + + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(session)) + return + } + + tokenPair, user, err := h.authService.RegisterOAuthEmailAccount( + c.Request.Context(), + email, + req.Password, + strings.TrimSpace(req.VerifyCode), + strings.TrimSpace(req.InvitationCode), + strings.TrimSpace(session.ProviderType), + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, req.adoptionDecision()) + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthBinding(c.Request.Context(), client, session, decision, &user.ID, true); err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) + return + } + + if _, err := pendingSvc.ConsumeBrowserSession(c.Request.Context(), session.SessionToken, session.BrowserSessionKey); err != nil { + clearCookies() + response.ErrorFrom(c, err) + return + } + + clearCookies() + writeOAuthTokenPairResponse(c, tokenPair) +} + // ExchangePendingOAuthCompletion redeems a pending OAuth browser session into a frontend-safe payload. // POST /api/v1/auth/oauth/pending/exchange func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 3afb4fb7..80338b8a 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -509,9 +509,305 @@ func TestExchangePendingOAuthCompletionInvitationRequiredFalseFalsePersistsDecis require.Nil(t, storedSession.ConsumedAt) } +func TestCreateOIDCOAuthAccountCreatesUserBindsIdentityAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "fresh@example.com", "246810") + ctx := context.Background() + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("create-account-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-create-123"). + SetBrowserSessionKey("create-account-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Fresh OIDC User", + "suggested_avatar_url": "https://cdn.example/fresh.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"fresh@example.com","verify_code":"246810","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("create-account-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + createdUser, err := client.User.Query().Where(dbuser.EmailEQ("fresh@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, service.StatusActive, createdUser.Status) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-create-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, createdUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-123"). + SetBrowserSessionKey("existing-email-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "pending_session", payload["auth_result"]) + require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + require.Equal(t, "oidc", payload["provider"]) + require.Equal(t, "/dashboard", payload["redirect"]) + require.Equal(t, true, payload["adoption_required"]) + require.Equal(t, "Existing OIDC User", payload["suggested_display_name"]) + require.Equal(t, "https://cdn.example/existing.png", payload["suggested_avatar_url"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", storedSession.Intent) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) + require.Nil(t, storedSession.ConsumedAt) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-existing-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) +} + +func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetRedirectTo("/profile"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"secret-123","adopt_display_name":false,"adopt_avatar":false}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.NotEmpty(t, payload["access_token"]) + require.NotEmpty(t, payload["refresh_token"]) + require.Equal(t, "Bearer", payload["token_type"]) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, existingUser.ID, identity.UserID) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.ConsumedAt) +} + +func TestBindOIDCOAuthLoginRejectsInvalidPasswordWithoutConsumingSession(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + ctx := context.Background() + + passwordHash, err := handler.authService.HashPassword("secret-123") + require.NoError(t, err) + + existingUser, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner-user"). + SetPasswordHash(passwordHash). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("bind-login-invalid-password-session-token"). + SetIntent("adopt_existing_user_by_email"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-bind-invalid-123"). + SetTargetUserID(existingUser.ID). + SetResolvedEmail(existingUser.Email). + SetBrowserSessionKey("bind-login-invalid-password-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Bound OIDC User", + "suggested_avatar_url": "https://cdn.example/bound.png", + }). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","password":"wrong-password"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/bind-login", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("bind-login-invalid-password-browser-session-key")}) + ginCtx.Request = req + + handler.BindOIDCOAuthLogin(ginCtx) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) + payload := decodeJSONBody(t, recorder) + require.Equal(t, "INVALID_CREDENTIALS", payload["reason"]) + + identityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("oidc-bind-invalid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.Nil(t, storedSession.ConsumedAt) +} + func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, false, nil) +} + +func newOAuthPendingFlowTestHandlerWithEmailVerification( + t *testing.T, + invitationEnabled bool, + email string, + code string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + cache := &oauthPendingFlowEmailCacheStub{ + verificationCodes: map[string]*service.VerificationCodeData{ + email: { + Code: code, + Attempts: 0, + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(15 * time.Minute), + }, + }, + } + return newOAuthPendingFlowTestHandlerWithOptions(t, invitationEnabled, true, cache) +} + +func newOAuthPendingFlowTestHandlerWithOptions( + t *testing.T, + invitationEnabled bool, + emailVerifyEnabled bool, + emailCache service.EmailCache, +) (*AuthHandler, *dbent.Client) { + t.Helper() + db, err := sql.Open("sqlite", "file:auth_oauth_pending_flow_handler?mode=memory&cache=shared") require.NoError(t, err) t.Cleanup(func() { _ = db.Close() }) @@ -538,9 +834,18 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth values: map[string]string{ service.SettingKeyRegistrationEnabled: "true", service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), }, }, cfg) userRepo := &oauthPendingFlowUserRepo{client: client} + var emailService *service.EmailService + if emailCache != nil { + emailService = service.NewEmailService(&oauthPendingFlowSettingRepoStub{ + values: map[string]string{ + service.SettingKeyEmailVerifyEnabled: boolSettingValue(emailVerifyEnabled), + }, + }, emailCache) + } authSvc := service.NewAuthService( client, userRepo, @@ -548,7 +853,7 @@ func newOAuthPendingFlowTestHandler(t *testing.T, invitationEnabled bool) (*Auth &oauthPendingFlowRefreshTokenCacheStub{}, cfg, settingSvc, - nil, + emailService, nil, nil, nil, @@ -622,6 +927,70 @@ func (s *oauthPendingFlowSettingRepoStub) Delete(context.Context, string) error type oauthPendingFlowRefreshTokenCacheStub struct{} +type oauthPendingFlowEmailCacheStub struct { + verificationCodes map[string]*service.VerificationCodeData +} + +func (s *oauthPendingFlowEmailCacheStub) GetVerificationCode(_ context.Context, email string) (*service.VerificationCodeData, error) { + if s == nil || s.verificationCodes == nil { + return nil, nil + } + return s.verificationCodes[email], nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetVerificationCode(_ context.Context, email string, data *service.VerificationCodeData, _ time.Duration) error { + if s.verificationCodes == nil { + s.verificationCodes = map[string]*service.VerificationCodeData{} + } + s.verificationCodes[email] = data + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteVerificationCode(_ context.Context, email string) error { + delete(s.verificationCodes, email) + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyVerifyCode(context.Context, string) (*service.VerificationCodeData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetNotifyVerifyCode(context.Context, string, *service.VerificationCodeData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeleteNotifyVerifyCode(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetPasswordResetToken(context.Context, string) (*service.PasswordResetTokenData, error) { + return nil, nil +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetToken(context.Context, string, *service.PasswordResetTokenData, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) DeletePasswordResetToken(context.Context, string) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IsPasswordResetEmailInCooldown(context.Context, string) bool { + return false +} + +func (s *oauthPendingFlowEmailCacheStub) SetPasswordResetEmailCooldown(context.Context, string, time.Duration) error { + return nil +} + +func (s *oauthPendingFlowEmailCacheStub) IncrNotifyCodeUserRate(context.Context, int64, time.Duration) (int64, error) { + return 0, nil +} + +func (s *oauthPendingFlowEmailCacheStub) GetNotifyCodeUserRate(context.Context, int64) (int64, error) { + return 0, nil +} + func (s *oauthPendingFlowRefreshTokenCacheStub) StoreRefreshToken(context.Context, string, *service.RefreshTokenData, time.Duration) error { return nil } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 0f79759e..909d6379 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -342,6 +342,21 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { idClaims.Name, oidcFallbackUsername(subject), ) + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: issuer, + ProviderSubject: subject, + } + upstreamClaims := map[string]any{ + "email": email, + "username": username, + "subject": subject, + "issuer": issuer, + "email_verified": emailVerified != nil && *emailVerified, + "provider_fallback": strings.TrimSpace(cfg.ProviderName), + "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), + "suggested_avatar_url": userInfoClaims.AvatarURL, + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, oidcOAuthBindUserCookieName) if err != nil { @@ -349,26 +364,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: oauthIntentBindCurrentUser, - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &targetUserID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: oauthIntentBindCurrentUser, + Identity: identityRef, + TargetUserID: &targetUserID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "redirect": redirectTo, }, @@ -380,30 +382,60 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: existingIdentityUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + "redirect": redirectTo, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + // 传入空邀请码;如果需要邀请码,服务层返回 ErrOAuthInvitationRequired tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if errors.Is(err, service.ErrOAuthInvitationRequired) { if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "error": "invitation_required", "redirect": redirectTo, @@ -420,26 +452,13 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ - Intent: "login", - Identity: service.PendingAuthIdentityKey{ - ProviderType: "oidc", - ProviderKey: issuer, - ProviderSubject: subject, - }, - TargetUserID: &user.ID, - ResolvedEmail: email, - RedirectTo: redirectTo, - BrowserSessionKey: browserSessionKey, - UpstreamIdentityClaims: map[string]any{ - "email": email, - "username": username, - "subject": subject, - "issuer": issuer, - "email_verified": emailVerified != nil && *emailVerified, - "provider_fallback": strings.TrimSpace(cfg.ProviderName), - "suggested_display_name": firstNonEmpty(userInfoClaims.DisplayName, idClaims.Name, username), - "suggested_avatar_url": userInfoClaims.AvatarURL, - }, + Intent: "login", + Identity: identityRef, + TargetUserID: &user.ID, + ResolvedEmail: email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, CompletionResponse: map[string]any{ "access_token": tokenPair.AccessToken, "refresh_token": tokenPair.RefreshToken, diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 45ac6cad..6d37c799 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -214,6 +214,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { "suggested_display_name": strings.TrimSpace(userInfo.Nickname), "suggested_avatar_url": strings.TrimSpace(userInfo.HeadImgURL), } + identityRef := service.PendingAuthIdentityKey{ + ProviderType: "wechat", + ProviderKey: wechatOAuthProviderKey, + ProviderSubject: providerSubject, + } normalizedIntent := normalizeWeChatOAuthIntent(intent) if normalizedIntent == wechatOAuthIntentBind { @@ -232,6 +237,34 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { return } + existingIdentityUser, err := h.findOAuthIdentityUser(c.Request.Context(), identityRef) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if existingIdentityUser != nil { + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") + if err != nil { + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, existingIdentityUser.Email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, nil, &user.ID); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { + if err := h.createOAuthEmailRequiredPendingSession(c, identityRef, redirectTo, browserSessionKey, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, "") if err != nil { if err := h.createWeChatPendingSession(c, normalizedIntent, providerSubject, email, redirectTo, browserSessionKey, upstreamClaims, tokenPair, err, nil); err != nil { diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index f44b3e3b..637e317b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -167,6 +167,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool `json:"registration_enabled"` EmailVerifyEnabled bool `json:"email_verify_enabled"` + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` RegistrationEmailSuffixWhitelist []string `json:"registration_email_suffix_whitelist"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index c7bc3e2a..9925f066 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -34,6 +34,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { response.Success(c, dto.PublicSettings{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, RegistrationEmailSuffixWhitelist: settings.RegistrationEmailSuffixWhitelist, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go new file mode 100644 index 00000000..114c7245 --- /dev/null +++ b/backend/internal/handler/setting_handler_public_test.go @@ -0,0 +1,83 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerPublicRepoStub struct { + values map[string]string +} + +func (s *settingHandlerPublicRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + panic("unexpected Get call") +} + +func (s *settingHandlerPublicRepoStub) GetValue(ctx context.Context, key string) (string, error) { + panic("unexpected GetValue call") +} + +func (s *settingHandlerPublicRepoStub) Set(ctx context.Context, key, value string) error { + panic("unexpected Set call") +} + +func (s *settingHandlerPublicRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerPublicRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingHandlerPublicRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingHandlerPublicRepoStub) Delete(ctx context.Context, key string) error { + panic("unexpected Delete call") +} + +func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + gin.SetMode(gin.TestMode) + + repo := &settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + h := NewSettingHandler(service.NewSettingService(repo, &config.Config{}), "test-version") + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/settings/public", nil) + + h.GetPublicSettings(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.True(t, resp.Data.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 7a34834d..1f28e9c3 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -72,18 +72,54 @@ func RegisterAuthRoutes( }), h.Auth.ExchangePendingOAuthCompletion, ) + auth.POST("/oauth/pending/create-account", + rateLimiter.LimitWithOptions("oauth-pending-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreatePendingOAuthAccount, + ) + auth.POST("/oauth/pending/bind-login", + rateLimiter.LimitWithOptions("oauth-pending-bind-login", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindPendingOAuthLogin, + ) auth.POST("/oauth/linuxdo/complete-registration", rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteLinuxDoOAuthRegistration, ) + auth.POST("/oauth/linuxdo/bind-login", + rateLimiter.LimitWithOptions("oauth-linuxdo-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindLinuxDoOAuthLogin, + ) + auth.POST("/oauth/linuxdo/create-account", + rateLimiter.LimitWithOptions("oauth-linuxdo-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateLinuxDoOAuthAccount, + ) auth.POST("/oauth/wechat/complete-registration", rateLimiter.LimitWithOptions("oauth-wechat-complete", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.CompleteWeChatOAuthRegistration, ) + auth.POST("/oauth/wechat/bind-login", + rateLimiter.LimitWithOptions("oauth-wechat-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindWeChatOAuthLogin, + ) + auth.POST("/oauth/wechat/create-account", + rateLimiter.LimitWithOptions("oauth-wechat-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateWeChatOAuthAccount, + ) auth.GET("/oauth/oidc/start", h.Auth.OIDCOAuthStart) auth.GET("/oauth/oidc/callback", h.Auth.OIDCOAuthCallback) auth.POST("/oauth/oidc/complete-registration", @@ -92,6 +128,18 @@ func RegisterAuthRoutes( }), h.Auth.CompleteOIDCOAuthRegistration, ) + auth.POST("/oauth/oidc/bind-login", + rateLimiter.LimitWithOptions("oauth-oidc-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindOIDCOAuthLogin, + ) + auth.POST("/oauth/oidc/create-account", + rateLimiter.LimitWithOptions("oauth-oidc-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateOIDCOAuthAccount, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go new file mode 100644 index 00000000..ca3403d4 --- /dev/null +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -0,0 +1,151 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" +) + +// VerifyOAuthEmailCode verifies the locally entered email verification code for +// third-party signup and binding flows. This is intentionally independent from +// the global registration email verification toggle. +func (s *AuthService) VerifyOAuthEmailCode(ctx context.Context, email, verifyCode string) error { + email = strings.TrimSpace(strings.ToLower(email)) + verifyCode = strings.TrimSpace(verifyCode) + + if email == "" { + return ErrEmailVerifyRequired + } + if verifyCode == "" { + return ErrEmailVerifyRequired + } + if s == nil || s.emailService == nil { + return ErrServiceUnavailable + } + return s.emailService.VerifyCode(ctx, email, verifyCode) +} + +// RegisterOAuthEmailAccount creates a local account from a third-party first +// login after the user has verified a local email address. +func (s *AuthService) RegisterOAuthEmailAccount( + ctx context.Context, + email string, + password string, + verifyCode string, + invitationCode string, + signupSource string, +) (*TokenPair, *User, error) { + if s == nil { + return nil, nil, ErrServiceUnavailable + } + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + email = strings.TrimSpace(strings.ToLower(email)) + if isReservedEmail(email) { + return nil, nil, ErrEmailReserved + } + if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + return nil, nil, err + } + if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + return nil, nil, err + } + + var invitationRedeemCode *RedeemCode + if s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return nil, nil, ErrInvitationCodeRequired + } + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + return nil, nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + + existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) + if err != nil { + return nil, nil, ErrServiceUnavailable + } + if existsEmail { + return nil, nil, ErrEmailExists + } + + hashedPassword, err := s.HashPassword(password) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + signupSource = strings.TrimSpace(strings.ToLower(signupSource)) + if signupSource == "" { + signupSource = "email" + } + grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + + user := &User{ + Email: email, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: grantPlan.Balance, + Concurrency: grantPlan.Concurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, user); err != nil { + if errors.Is(err, ErrEmailExists) { + return nil, nil, ErrEmailExists + } + return nil, nil, ErrServiceUnavailable + } + + s.postAuthUserBootstrap(ctx, user, signupSource, true) + s.assignSubscriptions(ctx, user.ID, grantPlan.Subscriptions, "auto assigned by signup defaults") + + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + return nil, nil, ErrInvitationCodeInvalid + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + +// ValidatePasswordCredentials checks the local password without completing the +// login flow. This is used by pending third-party account adoption flows before +// the external identity has been bound. +func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, password string) (*User, error) { + if s == nil { + return nil, ErrServiceUnavailable + } + + user, err := s.userRepo.GetByEmail(ctx, strings.TrimSpace(strings.ToLower(email))) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return nil, ErrInvalidCredentials + } + return nil, ErrServiceUnavailable + } + if !user.IsActive() { + return nil, ErrUserNotActive + } + if !s.CheckPassword(password, user.PasswordHash) { + return nil, ErrInvalidCredentials + } + return user, nil +} + +// RecordSuccessfulLogin updates last-login activity after a non-standard login +// flow finishes with a real session. +func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + s.touchUserLogin(ctx, userID) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index de555478..a2644fcd 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -217,6 +217,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings keys := []string{ SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, + SettingKeyForceEmailOnThirdPartySignup, SettingKeyRegistrationEmailSuffixWhitelist, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, @@ -294,6 +295,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, + ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 5cf1e860..bb97c2aa 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -77,3 +77,16 @@ func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) require.Equal(t, 50, settings.TableDefaultPageSize) require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) } + +func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyForceEmailOnThirdPartySignup: "true", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.ForceEmailOnThirdPartySignup) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index e991ebef..72db4e31 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -128,6 +128,7 @@ type DefaultSubscriptionSetting struct { type PublicSettings struct { RegistrationEnabled bool EmailVerifyEnabled bool + ForceEmailOnThirdPartySignup bool RegistrationEmailSuffixWhitelist []string PromoCodeEnabled bool PasswordResetEnabled bool -- GitLab From 6ea3f42e2f825f165c98a63fa6b5f472f6853b33 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 19:30:19 +0800 Subject: [PATCH 041/265] feat: add oauth callback email binding ui --- .../api/__tests__/auth-oauth-adoption.spec.ts | 91 +++++++ frontend/src/api/auth.ts | 98 +++++-- frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 1 + .../src/views/auth/LinuxDoCallbackView.vue | 257 +++++++++++++++++- frontend/src/views/auth/OidcCallbackView.vue | 128 ++++++++- .../src/views/auth/WechatCallbackView.vue | 108 ++++++++ .../__tests__/LinuxDoCallbackView.spec.ts | 105 +++++++ .../auth/__tests__/OidcCallbackView.spec.ts | 71 +++++ .../auth/__tests__/WechatCallbackView.spec.ts | 88 ++++++ 10 files changed, 914 insertions(+), 34 deletions(-) diff --git a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts index 9c0b4d55..f95332fb 100644 --- a/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts +++ b/frontend/src/api/__tests__/auth-oauth-adoption.spec.ts @@ -30,6 +30,20 @@ describe('oauth adoption auth api', () => { }) }) + it('posts bind-login decisions when finalizing pending oauth bind flow', async () => { + const { completePendingOAuthBindLogin } = await import('@/api/auth') + + await completePendingOAuthBindLogin({ + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/pending/exchange', { + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts linuxdo invitation completion with adoption decisions', async () => { const { completeLinuxDoOAuthRegistration } = await import('@/api/auth') @@ -45,6 +59,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts linuxdo create-account completion with adoption decisions', async () => { + const { createPendingLinuxDoOAuthAccount } = await import('@/api/auth') + + await createPendingLinuxDoOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: true + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/linuxdo/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: true + }) + }) + it('posts oidc invitation completion with adoption decisions', async () => { const { completeOIDCOAuthRegistration } = await import('@/api/auth') @@ -60,6 +89,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts oidc create-account completion with adoption decisions', async () => { + const { createPendingOIDCOAuthAccount } = await import('@/api/auth') + + await createPendingOIDCOAuthAccount('invite-code', { + adoptDisplayName: true, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/oidc/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: true, + adopt_avatar: false + }) + }) + it('posts wechat invitation completion with adoption decisions', async () => { const { completeWeChatOAuthRegistration } = await import('@/api/auth') @@ -75,6 +119,21 @@ describe('oauth adoption auth api', () => { }) }) + it('posts wechat create-account completion with adoption decisions', async () => { + const { createPendingWeChatOAuthAccount } = await import('@/api/auth') + + await createPendingWeChatOAuthAccount('invite-code', { + adoptDisplayName: false, + adoptAvatar: false + }) + + expect(post).toHaveBeenCalledWith('/auth/oauth/wechat/complete-registration', { + invitation_code: 'invite-code', + adopt_display_name: false, + adopt_avatar: false + }) + }) + it('classifies oauth completion results as login or bind', async () => { const { getOAuthCompletionKind } = await import('@/api/auth') @@ -82,6 +141,38 @@ describe('oauth adoption auth api', () => { expect(getOAuthCompletionKind({ redirect: '/profile' })).toBe('bind') }) + it('provides bind-login utility helpers for invitation and suggested profile states', async () => { + const { + getPendingOAuthBindLoginKind, + hasPendingOAuthSuggestedProfile, + isPendingOAuthCreateAccountRequired + } = await import('@/api/auth') + + expect(getPendingOAuthBindLoginKind({ access_token: 'access-token' })).toBe('login') + expect(getPendingOAuthBindLoginKind({ redirect: '/profile' })).toBe('bind') + expect( + isPendingOAuthCreateAccountRequired({ + error: 'invitation_required' + }) + ).toBe(true) + expect( + isPendingOAuthCreateAccountRequired({ + error: 'other' + }) + ).toBe(false) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_display_name: 'OAuth Nick' + }) + ).toBe(true) + expect( + hasPendingOAuthSuggestedProfile({ + suggested_avatar_url: 'https://cdn.example/avatar.png' + }) + ).toBe(true) + expect(hasPendingOAuthSuggestedProfile({})).toBe(false) + }) + it('prepares an oauth bind access token cookie before redirect binding', async () => { localStorage.setItem('auth_token', 'access-token-value') const setCookie = vi.fn() diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index c11bd90b..98b20154 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -193,7 +193,7 @@ export interface OAuthTokenResponse { token_type?: string } -export interface PendingOAuthExchangeResponse extends Partial { +export interface PendingOAuthBindLoginResponse extends Partial { redirect?: string error?: string adoption_required?: boolean @@ -201,6 +201,10 @@ export interface PendingOAuthExchangeResponse extends Partial +): boolean { + return completion.error === 'invitation_required' +} + +export function hasPendingOAuthSuggestedProfile( + completion: Pick< + PendingOAuthBindLoginResponse, + 'suggested_display_name' | 'suggested_avatar_url' + > +): boolean { + return Boolean(completion.suggested_display_name || completion.suggested_avatar_url) +} + export function persistOAuthTokenContext(tokens: Partial): void { if (tokens.refresh_token) { setRefreshToken(tokens.refresh_token) @@ -431,11 +456,7 @@ export async function completeLinuxDoOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/linuxdo/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) - return data + return createPendingLinuxDoOAuthAccount(invitationCode, decision) } /** @@ -447,34 +468,68 @@ export async function completeOIDCOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/oidc/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) - return data + return createPendingOIDCOAuthAccount(invitationCode, decision) } export async function completeWeChatOAuthRegistration( invitationCode: string, decision?: OAuthAdoptionDecision ): Promise { - const { data } = await apiClient.post('/auth/oauth/wechat/complete-registration', { - invitation_code: invitationCode, - ...serializeOAuthAdoptionDecision(decision) - }) + return createPendingWeChatOAuthAccount(invitationCode, decision) +} + +async function createPendingOAuthAccount( + provider: 'linuxdo' | 'oidc' | 'wechat', + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( + `/auth/oauth/${provider}/complete-registration`, + { + invitation_code: invitationCode, + ...serializeOAuthAdoptionDecision(decision) + } + ) return data } -export async function exchangePendingOAuthCompletion( +export async function createPendingLinuxDoOAuthAccount( + invitationCode: string, decision?: OAuthAdoptionDecision -): Promise { - const { data } = await apiClient.post( +): Promise { + return createPendingOAuthAccount('linuxdo', invitationCode, decision) +} + +export async function createPendingOIDCOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('oidc', invitationCode, decision) +} + +export async function createPendingWeChatOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision +): Promise { + return createPendingOAuthAccount('wechat', invitationCode, decision) +} + +export async function completePendingOAuthBindLogin( + decision?: OAuthAdoptionDecision +): Promise { + const { data } = await apiClient.post( '/auth/oauth/pending/exchange', serializeOAuthAdoptionDecision(decision) ) return data } +export async function exchangePendingOAuthCompletion( + decision?: OAuthAdoptionDecision +): Promise { + return completePendingOAuthBindLogin(decision) +} + export const authAPI = { login, login2FA, @@ -498,6 +553,13 @@ export const authAPI = { resetPassword, refreshToken, revokeAllSessions, + getPendingOAuthBindLoginKind, + isPendingOAuthCreateAccountRequired, + hasPendingOAuthSuggestedProfile, + completePendingOAuthBindLogin, + createPendingLinuxDoOAuthAccount, + createPendingOIDCOAuthAccount, + createPendingWeChatOAuthAccount, exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration, diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 1b1af87b..a8e03a51 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -316,6 +316,7 @@ export const useAppStore = defineStore('app', () => { return { registration_enabled: false, email_verify_enabled: false, + force_email_on_third_party_signup: false, registration_email_suffix_whitelist: [], promo_code_enabled: true, password_reset_enabled: false, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index a19d6c26..a4b2277b 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -142,6 +142,7 @@ export interface CustomEndpoint { export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean + force_email_on_third_party_signup: boolean registration_email_suffix_whitelist: string[] promo_code_enabled: boolean password_reset_enabled: boolean diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index 6dc8f242..00b73868 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -11,7 +11,10 @@
-
+
+ + + +
@@ -127,11 +214,12 @@ diff --git a/frontend/src/components/admin/monitor/MonitorFiltersBar.vue b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue new file mode 100644 index 00000000..ebb06a68 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorFiltersBar.vue @@ -0,0 +1,95 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorFormDialog.vue b/frontend/src/components/admin/monitor/MonitorFormDialog.vue new file mode 100644 index 00000000..920c3f79 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorFormDialog.vue @@ -0,0 +1,297 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue new file mode 100644 index 00000000..eefe4073 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorKeyPickerDialog.vue @@ -0,0 +1,64 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue new file mode 100644 index 00000000..eccec828 --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorPrimaryModelCell.vue @@ -0,0 +1,71 @@ + + + diff --git a/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue new file mode 100644 index 00000000..02fa6e8d --- /dev/null +++ b/frontend/src/components/admin/monitor/MonitorRunResultDialog.vue @@ -0,0 +1,56 @@ + + + diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 92dcc519..23d0f4e9 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -38,7 +38,7 @@ 'sidebar-link-collapsed': sidebarCollapsed }" :title="sidebarCollapsed ? item.label : undefined" - @click="sidebarCollapsed ? undefined : toggleGroup(item)" + @click="handleGroupClick(item)" > import { computed, h, onMounted, ref, watch } from 'vue' -import { useRoute } from 'vue-router' +import { useRoute, useRouter } from 'vue-router' import { useI18n } from 'vue-i18n' import { useAdminSettingsStore, useAppStore, useAuthStore, useOnboardingStore } from '@/stores' import VersionBadge from '@/components/common/VersionBadge.vue' @@ -194,11 +194,17 @@ interface NavItem { iconSvg?: string hideInSimpleMode?: boolean children?: NavItem[] + /** + * When true, the parent item only toggles the expand/collapse state and + * does NOT navigate to its `path`. The `path` is purely a stable key. + */ + expandOnly?: boolean } const { t } = useI18n() const route = useRoute() +const router = useRouter() const appStore = useAppStore() const authStore = useAuthStore() const onboardingStore = useOnboardingStore() @@ -549,6 +555,41 @@ const ChevronDoubleRightIcon = { ) } +const SignalIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9.348 14.651a3.75 3.75 0 010-5.303m5.304 0a3.75 3.75 0 010 5.303m-7.425 2.122a6.75 6.75 0 010-9.546m9.546 0a6.75 6.75 0 010 9.546M5.106 18.894c-3.808-3.807-3.808-9.98 0-13.788m13.788 0c3.808 3.807 3.808 9.98 0 13.788M12 12h.008v.008H12V12zm.375 0a.375.375 0 11-.75 0 .375.375 0 01.75 0z' + }) + ] + ) +} + +const PriceTagIcon = { + render: () => + h( + 'svg', + { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, + [ + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M9.568 3H5.25A2.25 2.25 0 003 5.25v4.318c0 .597.237 1.17.659 1.591l9.581 9.581c.699.699 1.78.872 2.607.33a18.095 18.095 0 005.223-5.223c.542-.827.369-1.908-.33-2.607L11.16 3.66A2.25 2.25 0 009.568 3z' + }), + h('path', { + 'stroke-linecap': 'round', + 'stroke-linejoin': 'round', + d: 'M6 6h.008v.008H6V6z' + }) + ] + ) +} + const ChevronDownIcon = { render: () => h( @@ -570,6 +611,7 @@ const userNavItems = computed((): NavItem[] => { { path: '/dashboard', label: t('nav.dashboard'), icon: DashboardIcon }, { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, + { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }, { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, ...(appStore.cachedPublicSettings?.payment_enabled ? [ @@ -608,6 +650,7 @@ const personalNavItems = computed((): NavItem[] => { const items: NavItem[] = [ { path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }, { path: '/usage', label: t('nav.usage'), icon: ChartIcon, hideInSimpleMode: true }, + { path: '/monitor', label: t('nav.channelStatus'), icon: SignalIcon }, { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, ...(appStore.cachedPublicSettings?.payment_enabled ? [ @@ -664,7 +707,17 @@ const adminNavItems = computed((): NavItem[] => { : []), { path: '/admin/users', label: t('nav.users'), icon: UsersIcon, hideInSimpleMode: true }, { path: '/admin/groups', label: t('nav.groups'), icon: FolderIcon, hideInSimpleMode: true }, - { path: '/admin/channels', label: t('nav.channels', '渠道管理'), icon: ChannelIcon, hideInSimpleMode: true }, + { + path: '/admin/channels', + label: t('nav.channelManagement'), + icon: ChannelIcon, + hideInSimpleMode: true, + expandOnly: true, + children: [ + { path: '/admin/channels/pricing', label: t('nav.channelPricing'), icon: PriceTagIcon }, + { path: '/admin/channels/monitor', label: t('nav.channelMonitor'), icon: SignalIcon }, + ], + }, { path: '/admin/subscriptions', label: t('nav.subscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, { path: '/admin/accounts', label: t('nav.accounts'), icon: GlobeIcon }, { path: '/admin/announcements', label: t('nav.announcements'), icon: BellIcon }, @@ -678,6 +731,7 @@ const adminNavItems = computed((): NavItem[] => { label: t('nav.orderManagement'), icon: OrderIcon, hideInSimpleMode: true, + expandOnly: true, children: [ { path: '/admin/orders/dashboard', label: t('nav.paymentDashboard'), icon: ChartIcon }, { path: '/admin/orders', label: t('nav.orderManagement'), icon: OrderIcon }, @@ -764,6 +818,28 @@ function toggleGroup(item: NavItem) { } } +/** + * Click handler for collapsible parent items. + * - When sidebar is collapsed: do nothing (children are not visible). + * - When `expandOnly` is true: only toggle expand state. + * - Otherwise (default, e.g. /admin/orders): navigate to the parent path + * (router-link semantics) and ensure the group is expanded. + */ +function handleGroupClick(item: NavItem) { + if (sidebarCollapsed.value) return + if (item.expandOnly) { + toggleGroup(item) + return + } + // Push to path and ensure expanded + if (route.path !== item.path) { + router.push(item.path) + } + if (!expandedGroups.value.has(item.path)) { + expandedGroups.value.add(item.path) + } +} + // Initialize theme const savedTheme = localStorage.getItem('theme') if ( diff --git a/frontend/src/components/user/MonitorDetailDialog.vue b/frontend/src/components/user/MonitorDetailDialog.vue new file mode 100644 index 00000000..564f461b --- /dev/null +++ b/frontend/src/components/user/MonitorDetailDialog.vue @@ -0,0 +1,114 @@ + + + diff --git a/frontend/src/components/user/MonitorPrimaryModelCell.vue b/frontend/src/components/user/MonitorPrimaryModelCell.vue new file mode 100644 index 00000000..32620b2a --- /dev/null +++ b/frontend/src/components/user/MonitorPrimaryModelCell.vue @@ -0,0 +1,71 @@ + + + diff --git a/frontend/src/composables/useChannelMonitorFormat.ts b/frontend/src/composables/useChannelMonitorFormat.ts new file mode 100644 index 00000000..fbb310fa --- /dev/null +++ b/frontend/src/composables/useChannelMonitorFormat.ts @@ -0,0 +1,97 @@ +/** + * Shared formatting helpers for channel monitor views (admin + user). + * + * Centralises: + * - status / provider label + badge class lookups + * - latency / availability / percent number formatting + * + * i18n keys live under `monitorCommon.*` so admin and user views share the + * same translation source. + */ + +import { useI18n } from 'vue-i18n' +import type { MonitorStatus, Provider } from '@/api/admin/channelMonitor' +import { + PROVIDER_OPENAI, + PROVIDER_ANTHROPIC, + PROVIDER_GEMINI, + STATUS_OPERATIONAL, + STATUS_DEGRADED, + STATUS_FAILED, + STATUS_ERROR, +} from '@/constants/channelMonitor' + +const NEUTRAL_BADGE = 'bg-gray-100 text-gray-800 dark:bg-dark-700 dark:text-gray-300' + +export interface AvailabilityRow { + primary_status: MonitorStatus | '' + availability_7d: number | null | undefined +} + +export function useChannelMonitorFormat() { + const { t } = useI18n() + + function statusLabel(s: MonitorStatus | ''): string { + if (!s) return t('monitorCommon.status.unknown') + return t(`monitorCommon.status.${s}`) + } + + function statusBadgeClass(s: MonitorStatus | ''): string { + switch (s) { + case STATUS_OPERATIONAL: + return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + case STATUS_DEGRADED: + return 'bg-yellow-100 text-yellow-800 dark:bg-yellow-900/30 dark:text-yellow-300' + case STATUS_FAILED: + return 'bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-300' + case STATUS_ERROR: + default: + return NEUTRAL_BADGE + } + } + + function providerLabel(p: Provider | string): string { + if (p === PROVIDER_OPENAI || p === PROVIDER_ANTHROPIC || p === PROVIDER_GEMINI) { + return t(`monitorCommon.providers.${p}`) + } + return p || '-' + } + + function providerBadgeClass(p: Provider | string): string { + switch (p) { + case PROVIDER_OPENAI: + return 'bg-green-100 text-green-800 dark:bg-green-900/30 dark:text-green-300' + case PROVIDER_ANTHROPIC: + return 'bg-orange-100 text-orange-800 dark:bg-orange-900/30 dark:text-orange-300' + case PROVIDER_GEMINI: + return 'bg-blue-100 text-blue-800 dark:bg-blue-900/30 dark:text-blue-300' + default: + return NEUTRAL_BADGE + } + } + + function formatLatency(ms: number | null | undefined): string { + if (ms == null) return t('monitorCommon.latencyEmpty') + return String(Math.round(ms)) + } + + function formatPercent(v: number | null | undefined): string { + if (v == null || Number.isNaN(v)) return '-' + return `${v.toFixed(2)}%` + } + + function formatAvailability(row: AvailabilityRow): string { + if (!row.primary_status) return '-' + return formatPercent(row.availability_7d) + } + + return { + statusLabel, + statusBadgeClass, + providerLabel, + providerBadgeClass, + formatLatency, + formatPercent, + formatAvailability, + } +} diff --git a/frontend/src/constants/channelMonitor.ts b/frontend/src/constants/channelMonitor.ts new file mode 100644 index 00000000..7523a878 --- /dev/null +++ b/frontend/src/constants/channelMonitor.ts @@ -0,0 +1,35 @@ +/** + * Channel monitor shared constants. + * + * Single source of truth for provider/status string values used by both the + * admin (`views/admin/ChannelMonitorView.vue`) and user-facing + * (`views/user/ChannelStatusView.vue`) screens, plus the shared composable + * `useChannelMonitorFormat`. + */ + +import type { Provider, MonitorStatus } from '@/api/admin/channelMonitor' + +export const PROVIDER_OPENAI: Provider = 'openai' +export const PROVIDER_ANTHROPIC: Provider = 'anthropic' +export const PROVIDER_GEMINI: Provider = 'gemini' + +export const PROVIDERS: readonly Provider[] = [ + PROVIDER_OPENAI, + PROVIDER_ANTHROPIC, + PROVIDER_GEMINI, +] + +export const STATUS_OPERATIONAL: MonitorStatus = 'operational' +export const STATUS_DEGRADED: MonitorStatus = 'degraded' +export const STATUS_FAILED: MonitorStatus = 'failed' +export const STATUS_ERROR: MonitorStatus = 'error' + +export const MONITOR_STATUSES: readonly MonitorStatus[] = [ + STATUS_OPERATIONAL, + STATUS_DEGRADED, + STATUS_FAILED, + STATUS_ERROR, +] + +/** Default polling interval (seconds) for new monitors. */ +export const DEFAULT_INTERVAL_SECONDS = 60 diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 1b7ffa81..32fbce19 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -245,6 +245,7 @@ export default { // Common common: { loading: 'Loading...', + submitting: 'Submitting...', justNow: 'just now', save: 'Save', saved: 'Saved successfully', @@ -363,7 +364,11 @@ export default { orderManagement: 'Orders', paymentDashboard: 'Payment Dashboard', paymentConfig: 'Payment Config', - paymentPlans: 'Plans' + paymentPlans: 'Plans', + channelManagement: 'Channels', + channelPricing: 'Channel Pricing', + channelMonitor: 'Channel Monitor', + channelStatus: 'Channel Status', }, // Auth @@ -846,6 +851,58 @@ export default { userAgent: 'User-Agent' }, + // Shared keys for channel monitor (admin + user views) + monitorCommon: { + status: { + operational: 'Operational', + degraded: 'Degraded', + failed: 'Failed', + error: 'Error', + unknown: '-' + }, + providers: { + openai: 'OpenAI', + anthropic: 'Anthropic', + gemini: 'Gemini' + }, + extraModelsHeader: 'Extra Models', + extraModelsEmpty: 'No extra models', + latencyEmpty: '-' + }, + + // Channel Status (user-facing read-only view) + channelStatus: { + title: 'Channel Status', + description: 'Inspect channel availability, latency and recent status', + searchPlaceholder: 'Search channels...', + allProviders: 'All Providers', + loadError: 'Failed to load channel status', + detailLoadError: 'Failed to load channel detail', + detailTitle: 'Channel Detail', + closeDetail: 'Close', + columns: { + name: 'Name', + provider: 'Provider', + groupName: 'Group', + primaryModel: 'Primary Model', + availability7d: '7d Availability', + latency: 'Latency (ms)' + }, + detailColumns: { + model: 'Model', + latestStatus: 'Latest Status', + latestLatency: 'Latest Latency (ms)', + availability7d: '7d Availability', + availability15d: '15d Availability', + availability30d: '30d Availability', + avgLatency7d: '7d Avg Latency (ms)' + }, + empty: { + title: 'No channels available', + description: 'No monitored channels have been configured yet.' + } + }, + // Redeem redeem: { title: 'Redeem Code', @@ -2014,6 +2071,69 @@ export default { } }, + // Channel Monitor + channelMonitor: { + title: 'Channel Monitor', + description: 'Monitor channel availability, latency and status', + searchPlaceholder: 'Search monitor name...', + allProviders: 'All Providers', + allStatus: 'All Status', + enabledFilter: 'Enabled', + onlyEnabled: 'Enabled only', + onlyDisabled: 'Disabled only', + createButton: 'Create Monitor', + createTitle: 'Create Channel Monitor', + editTitle: 'Edit Channel Monitor', + runNow: 'Run Now', + runSuccess: 'Check completed', + runFailed: 'Check failed', + apiKeyDecryptFailed: 'API Key decryption failed. Please re-edit this monitor with a fresh key.', + createSuccess: 'Monitor created', + updateSuccess: 'Monitor updated', + deleteSuccess: 'Monitor deleted', + loadError: 'Failed to load monitors', + deleteConfirm: 'Are you sure you want to delete monitor "{name}"? This action cannot be undone.', + nameRequired: 'Please enter a monitor name', + primaryModelRequired: 'Please enter a primary model', + columns: { + name: 'Name', + provider: 'Provider', + primaryModel: 'Primary Model', + availability7d: '7d Availability', + latency: 'Latency (ms)', + enabled: 'Enabled', + actions: 'Actions' + }, + form: { + name: 'Name', + namePlaceholder: 'Enter monitor name', + provider: 'Provider', + endpoint: 'Endpoint', + endpointPlaceholder: 'https://api.example.com', + useCurrentDomain: 'Use current service', + apiKey: 'API Key', + apiKeyPlaceholder: 'Enter API Key', + apiKeyEditPlaceholder: 'Leave blank to keep current key', + useMyKey: 'Use my key', + selectKeyTitle: 'Select my API Key', + selectKeyHint: 'Only your active, non-expired keys are listed.', + noActiveKey: 'No active API keys available', + primaryModel: 'Primary Model', + primaryModelPlaceholder: 'gpt-4o-mini', + extraModels: 'Extra Models', + extraModelsPlaceholder: 'Press Enter to add extra model', + groupName: 'Group Name', + groupNamePlaceholder: 'Optional, used to group rows in user view', + intervalSeconds: 'Interval (seconds)', + intervalSecondsHint: 'Range: 15 - 3600 seconds', + enabled: 'Enable monitor', + kindRequired: 'Please select a provider' + }, + runResultTitle: 'Check Result', + noMonitorsYet: 'No monitors yet', + createFirstMonitor: 'Create your first monitor to track channel availability' + }, + // Subscriptions subscriptions: { title: 'Subscription Management', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index beb6841f..dd3af363 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -245,6 +245,7 @@ export default { // Common common: { loading: '加载中...', + submitting: '提交中...', justNow: '刚刚', save: '保存', saved: '保存成功', @@ -363,7 +364,11 @@ export default { orderManagement: '订单管理', paymentDashboard: '支付概览', paymentConfig: '支付配置', - paymentPlans: '订阅套餐' + paymentPlans: '订阅套餐', + channelManagement: '渠道管理', + channelPricing: '渠道定价', + channelMonitor: '渠道监控', + channelStatus: '渠道状态', }, // Auth @@ -850,6 +855,58 @@ export default { userAgent: 'User-Agent' }, + // Shared keys for channel monitor (admin + user views) + monitorCommon: { + status: { + operational: '正常', + degraded: '降级', + failed: '失败', + error: '错误', + unknown: '-' + }, + providers: { + openai: 'OpenAI', + anthropic: 'Anthropic', + gemini: 'Gemini' + }, + extraModelsHeader: '附加模型', + extraModelsEmpty: '无附加模型', + latencyEmpty: '-' + }, + + // Channel Status (user-facing read-only view) + channelStatus: { + title: '渠道状态', + description: '查看渠道可用性、延迟和近期状态', + searchPlaceholder: '搜索渠道...', + allProviders: '全部供应商', + loadError: '加载渠道状态失败', + detailLoadError: '加载渠道详情失败', + detailTitle: '渠道详情', + closeDetail: '关闭', + columns: { + name: '名称', + provider: '供应商', + groupName: '分组', + primaryModel: '主模型', + availability7d: '7 天可用率', + latency: '延迟 (ms)' + }, + detailColumns: { + model: '模型', + latestStatus: '最新状态', + latestLatency: '最新延迟 (ms)', + availability7d: '7 天可用率', + availability15d: '15 天可用率', + availability30d: '30 天可用率', + avgLatency7d: '7 天平均延迟 (ms)' + }, + empty: { + title: '暂无可显示的渠道', + description: '管理员尚未配置可监控的渠道。' + } + }, + // Redeem redeem: { title: '兑换码', @@ -2093,6 +2150,69 @@ export default { } }, + // Channel Monitor + channelMonitor: { + title: '渠道监控', + description: '监测各渠道的可用性、延迟和状态', + searchPlaceholder: '搜索监控名称...', + allProviders: '全部供应商', + allStatus: '全部状态', + enabledFilter: '启用状态', + onlyEnabled: '仅启用', + onlyDisabled: '仅禁用', + createButton: '新增监控', + createTitle: '新增渠道监控', + editTitle: '编辑渠道监控', + runNow: '立即检测', + runSuccess: '检测完成', + runFailed: '检测失败', + apiKeyDecryptFailed: 'API Key 解密失败,请重新编辑该监控并填入新的 Key', + createSuccess: '监控创建成功', + updateSuccess: '监控更新成功', + deleteSuccess: '监控删除成功', + loadError: '加载监控列表失败', + deleteConfirm: '确定要删除监控「{name}」吗?此操作不可撤销。', + nameRequired: '请输入监控名称', + primaryModelRequired: '请输入主模型', + columns: { + name: '名称', + provider: '供应商', + primaryModel: '主模型', + availability7d: '7 天可用率', + latency: '延迟 (ms)', + enabled: '启用', + actions: '操作' + }, + form: { + name: '名称', + namePlaceholder: '输入监控名称', + provider: '供应商', + endpoint: '上游地址', + endpointPlaceholder: 'https://api.example.com', + useCurrentDomain: '使用当前服务', + apiKey: 'API Key', + apiKeyPlaceholder: '请输入 API Key', + apiKeyEditPlaceholder: '留空表示不修改', + useMyKey: '使用我的 Key', + selectKeyTitle: '选择我的 API Key', + selectKeyHint: '仅显示当前账号下处于「启用」状态且未过期的 Key。', + noActiveKey: '没有可用的启用状态 Key', + primaryModel: '主模型', + primaryModelPlaceholder: 'gpt-4o-mini', + extraModels: '附加模型', + extraModelsPlaceholder: '回车添加附加模型', + groupName: '分组名称', + groupNamePlaceholder: '可选,用于在用户视图中聚合显示', + intervalSeconds: '检测间隔 (秒)', + intervalSecondsHint: '范围:15 - 3600 秒', + enabled: '启用监控', + kindRequired: '请选择供应商' + }, + runResultTitle: '检测结果', + noMonitorsYet: '暂无监控', + createFirstMonitor: '创建第一个监控来跟踪渠道可用性' + }, + // Subscriptions Management subscriptions: { title: '订阅管理', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index b97ccb5d..491a984d 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -360,6 +360,10 @@ const routes: RouteRecordRaw[] = [ }, { path: '/admin/channels', + redirect: '/admin/channels/pricing' + }, + { + path: '/admin/channels/pricing', name: 'AdminChannels', component: () => import('@/views/admin/ChannelsView.vue'), meta: { @@ -370,6 +374,29 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.channels.description' } }, + { + path: '/admin/channels/monitor', + name: 'AdminChannelMonitor', + component: () => import('@/views/admin/ChannelMonitorView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Channel Monitor', + titleKey: 'admin.channelMonitor.title', + descriptionKey: 'admin.channelMonitor.description' + } + }, + { + path: '/monitor', + name: 'ChannelStatus', + component: () => import('@/views/user/ChannelStatusView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: false, + title: 'Channel Status', + titleKey: 'nav.channelStatus' + } + }, { path: '/admin/subscriptions', name: 'AdminSubscriptions', diff --git a/frontend/src/views/admin/ChannelMonitorView.vue b/frontend/src/views/admin/ChannelMonitorView.vue new file mode 100644 index 00000000..8f0a1e2f --- /dev/null +++ b/frontend/src/views/admin/ChannelMonitorView.vue @@ -0,0 +1,295 @@ + + + diff --git a/frontend/src/views/user/ChannelStatusView.vue b/frontend/src/views/user/ChannelStatusView.vue new file mode 100644 index 00000000..9f5fe8d1 --- /dev/null +++ b/frontend/src/views/user/ChannelStatusView.vue @@ -0,0 +1,208 @@ + + + -- GitLab From 58b2cc380fefc180d96c3baf10d3214026c1341c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:22:00 +0800 Subject: [PATCH 047/265] test: harden payment result resume flow --- frontend/src/views/user/PaymentResultView.vue | 50 +++---- .../user/__tests__/PaymentResultView.spec.ts | 132 ++++++++++++++++++ 2 files changed, 157 insertions(+), 25 deletions(-) create mode 100644 frontend/src/views/user/__tests__/PaymentResultView.spec.ts diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index 6431ddf6..e1db3ce2 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -94,6 +94,7 @@ import { ref, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' import OrderStatusBadge from '@/components/payment/OrderStatusBadge.vue' +import { PAYMENT_RECOVERY_STORAGE_KEY, readPaymentRecoverySnapshot } from '@/components/payment/paymentFlow' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' import type { PaymentOrder } from '@/types/payment' @@ -129,46 +130,46 @@ const feeAmount = computed(() => { }) const isSuccess = computed(() => { - // Always prioritize actual order status from backend - if (order.value) { - return SUCCESS_STATUSES.has(order.value.status) - } - // Fallback only when order not loaded - if (route.query.status === 'success') return true - if (route.query.trade_status === 'TRADE_SUCCESS') return true - return false + return !!order.value && SUCCESS_STATUSES.has(order.value.status) }) -/** Extract numeric order ID from out_trade_no like "sub2_46" → 46 */ -function parseOutTradeNo(outTradeNo: string): number { - const match = outTradeNo.match(/_(\d+)$/) - return match ? Number(match[1]) : 0 -} - onMounted(async () => { - // Try order_id first (internal navigation from QRCode/Stripe pages) + const resumeToken = typeof route.query.resume_token === 'string' + ? route.query.resume_token + : '' let orderId = Number(route.query.order_id) || 0 const outTradeNo = String(route.query.out_trade_no || '') - // Fallback: EasyPay return URL with out_trade_no - if (!orderId && outTradeNo) { - orderId = parseOutTradeNo(outTradeNo) - // Store return info for display when order lookup fails + if (!orderId && resumeToken && typeof window !== 'undefined') { + const restored = readPaymentRecoverySnapshot( + window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY), + { resumeToken }, + ) + if (restored?.orderId) { + orderId = restored.orderId + } + } + + if (orderId) { + try { + order.value = await paymentStore.pollOrderStatus(orderId) + } catch (_err: unknown) { + // Order lookup failed, will try legacy fallback below when possible. + } + } + + if (!order.value && outTradeNo) { returnInfo.value = { outTradeNo, money: String(route.query.money || ''), type: String(route.query.type || ''), tradeStatus: String(route.query.trade_status || ''), } - } - // Verify payment via public endpoint (works without login) - if (outTradeNo) { try { const result = await paymentAPI.verifyOrderPublic(outTradeNo) order.value = result.data } catch (_err: unknown) { - // Public verify failed, try authenticated endpoint if logged in try { const result = await paymentAPI.verifyOrder(outTradeNo) order.value = result.data @@ -176,12 +177,11 @@ onMounted(async () => { } } - // Normal order lookup by ID (if verify didn't load the order) if (!order.value && orderId) { try { order.value = await paymentStore.pollOrderStatus(orderId) } catch (_err: unknown) { - // Order lookup failed, will show returnInfo fallback + // Order lookup failed, will show returnInfo fallback. } } loading.value = false diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts new file mode 100644 index 00000000..b06217ab --- /dev/null +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -0,0 +1,132 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +const routeState = vi.hoisted(() => ({ + query: {} as Record, +})) + +const routerPush = vi.hoisted(() => vi.fn()) +const pollOrderStatus = vi.hoisted(() => vi.fn()) +const verifyOrderPublic = vi.hoisted(() => vi.fn()) +const verifyOrder = vi.hoisted(() => vi.fn()) + +vi.mock('vue-router', async () => { + const actual = await vi.importActual('vue-router') + return { + ...actual, + useRoute: () => routeState, + useRouter: () => ({ push: routerPush }), + } +}) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key, + }), + } +}) + +vi.mock('@/stores/payment', () => ({ + usePaymentStore: () => ({ + pollOrderStatus, + }), +})) + +vi.mock('@/api/payment', () => ({ + paymentAPI: { + verifyOrderPublic, + verifyOrder, + }, +})) + +import PaymentResultView from '../PaymentResultView.vue' +import { PAYMENT_RECOVERY_STORAGE_KEY } from '@/components/payment/paymentFlow' + +const orderFactory = (status: string) => ({ + id: 42, + user_id: 9, + amount: 88, + pay_amount: 88, + fee_rate: 0, + payment_type: 'alipay', + out_trade_no: 'sub2_20260420abcd1234', + status, + order_type: 'balance', + created_at: '2026-04-20T12:00:00Z', + expires_at: '2026-04-20T12:30:00Z', + refund_amount: 0, +}) + +describe('PaymentResultView', () => { + beforeEach(() => { + routeState.query = {} + routerPush.mockReset() + pollOrderStatus.mockReset() + verifyOrderPublic.mockReset() + verifyOrder.mockReset() + window.localStorage.clear() + }) + + it('restores order id from a matching resume token and does not trust query success flags', async () => { + routeState.query = { + resume_token: 'resume-42', + status: 'success', + } + window.localStorage.setItem(PAYMENT_RECOVERY_STORAGE_KEY, JSON.stringify({ + orderId: 42, + amount: 88, + qrCode: '', + expiresAt: '2099-01-01T00:10:00.000Z', + paymentType: 'alipay', + payUrl: 'https://pay.example.com/session/42', + clientSecret: '', + payAmount: 88, + orderType: 'balance', + paymentMode: 'redirect', + resumeToken: 'resume-42', + createdAt: Date.UTC(2099, 0, 1, 0, 0, 0), + })) + pollOrderStatus.mockResolvedValue(orderFactory('PENDING')) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(pollOrderStatus).toHaveBeenCalledWith(42) + expect(verifyOrderPublic).not.toHaveBeenCalled() + expect(wrapper.text()).toContain('payment.result.failed') + expect(wrapper.text()).not.toContain('payment.result.success') + }) + + it('keeps legacy out_trade_no verification as a fallback when no order context is available', async () => { + routeState.query = { + out_trade_no: 'legacy-123', + trade_status: 'TRADE_SUCCESS', + } + verifyOrderPublic.mockResolvedValue({ + data: orderFactory('PAID'), + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123') + expect(wrapper.text()).toContain('payment.result.success') + }) +}) -- GitLab From 40d4e167cd8093bc3f21d4dbe205946df9c0aead Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 20 Apr 2026 20:06:53 +0800 Subject: [PATCH 048/265] feat(payment): i18n payment error codes and label localization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pairs with the backend structured payment errors (reason + metadata). The frontend now maps reason codes to localized messages with metadata as interpolation variables, and automatically localizes raw config-field names (e.g. "certSerial" → "证书序列号") using the existing UI-label i18n namespace. - frontend/src/utils/apiError.ts - extractApiErrorCode now prefers the string `reason` over the numeric HTTP `code`; reason is granular enough to drive i18n lookup, HTTP code is not. - New extractApiErrorMetadata to pull interpolation params off the error. - New extractI18nErrorMessage(err, t, namespace, fallback): looks up `.` in i18n and substitutes metadata. Before substitution, `metadata.key` and `metadata.keys` (slash-joined) are re-translated through `admin.settings.payment.field_` so users see "缺少必填项:证书序列号" instead of "缺少必填项:certSerial". - frontend/src/i18n/locales/{zh,en}.ts - Add payment.errors entries for every structured reason code returned by the backend (PAYMENT_DISABLED, INVALID_AMOUNT, TOO_MANY_PENDING, DAILY_LIMIT_EXCEEDED, NO_AVAILABLE_INSTANCE, PAYMENT_PROVIDER_MISCONFIGURED, WXPAY_CONFIG_MISSING_KEY / INVALID_KEY_LENGTH / INVALID_KEY, NOT_FOUND, FORBIDDEN, CONFLICT, INVALID_ORDER_TYPE, INVALID_STATUS, BALANCE_NOT_ENOUGH, REFUND_AMOUNT_EXCEEDED, REFUND_FAILED, and more), with placeholders for template variables. - 13 payment-related Vue files - Migrate catch-block error reporting from extractApiErrorMessage to extractI18nErrorMessage(err, t, 'payment.errors', fallback). - Remove the ad-hoc paymentErrorMap computed in SettingsView.vue, which the new helper supersedes (it reads i18n directly via t). - frontend/src/components/payment/providerConfig.ts - wxpay: publicKey and publicKeyId are now required (was optional), matching the pubkey-only verifier direction; certSerial is already required. This PR is drop-in safe: reason-preferring extractApiErrorCode is backward compatible with callers that pass their own i18nMap, and error codes missing from i18n fall back to the existing message-based path. --- .../components/payment/PaymentQRDialog.vue | 4 +- .../components/payment/PaymentStatusPanel.vue | 4 +- .../payment/StripePaymentInline.vue | 8 +- .../src/components/payment/providerConfig.ts | 4 +- frontend/src/i18n/locales/en.ts | 26 ++++++ frontend/src/i18n/locales/zh.ts | 26 ++++++ frontend/src/utils/apiError.ts | 84 ++++++++++++++++++- frontend/src/views/admin/SettingsView.vue | 18 ++-- .../views/admin/orders/AdminOrdersView.vue | 10 +-- .../orders/AdminPaymentDashboardView.vue | 4 +- .../admin/orders/AdminPaymentPlansView.vue | 8 +- frontend/src/views/user/PaymentQRCodeView.vue | 4 +- frontend/src/views/user/PaymentView.vue | 6 +- frontend/src/views/user/StripePaymentView.vue | 6 +- frontend/src/views/user/StripePopupView.vue | 4 +- frontend/src/views/user/UserOrdersView.vue | 8 +- 16 files changed, 177 insertions(+), 47 deletions(-) diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue index db90c3b6..09d273cc 100644 --- a/frontend/src/components/payment/PaymentQRDialog.vue +++ b/frontend/src/components/payment/PaymentQRDialog.vue @@ -78,7 +78,7 @@ import Icon from '@/components/icons/Icon.vue' import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import QRCode from 'qrcode' @@ -222,7 +222,7 @@ async function handleCancel() { cleanup() emit('close') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue index 17541e59..53989dee 100644 --- a/frontend/src/components/payment/PaymentStatusPanel.vue +++ b/frontend/src/components/payment/PaymentStatusPanel.vue @@ -124,7 +124,7 @@ import { useI18n } from 'vue-i18n' import { usePaymentStore } from '@/stores/payment' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' import type { PaymentOrder } from '@/types/payment' import Icon from '@/components/icons/Icon.vue' @@ -242,7 +242,7 @@ async function handleCancel() { cleanup() outcome.value = 'cancelled' } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/StripePaymentInline.vue b/frontend/src/components/payment/StripePaymentInline.vue index 3ddff8c8..bdb0dd6b 100644 --- a/frontend/src/components/payment/StripePaymentInline.vue +++ b/frontend/src/components/payment/StripePaymentInline.vue @@ -67,7 +67,7 @@ import { ref, onMounted, nextTick } from 'vue' import { useI18n } from 'vue-i18n' import { useRouter } from 'vue-router' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { paymentAPI } from '@/api/payment' import { useAppStore } from '@/stores' import { getPaymentPopupFeatures } from '@/components/payment/providerConfig' @@ -132,7 +132,7 @@ onMounted(async () => { selectedType.value = event.value.type }) } catch (err: unknown) { - initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } finally { loading.value = false } @@ -186,7 +186,7 @@ async function handlePay() { emit('success') } } catch (err: unknown) { - error.value = extractApiErrorMessage(err, t('payment.result.failed')) + error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } finally { submitting.value = false } @@ -199,7 +199,7 @@ async function handleCancel() { await paymentAPI.cancelOrder(props.orderId) emit('back') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts index bf2d4177..f4f5acdc 100644 --- a/frontend/src/components/payment/providerConfig.ts +++ b/frontend/src/components/payment/providerConfig.ts @@ -99,9 +99,9 @@ export const PROVIDER_CONFIG_FIELDS: Record = { { key: 'mchId', label: '', sensitive: false }, { key: 'privateKey', label: '', sensitive: true }, { key: 'apiV3Key', label: '', sensitive: true }, + { key: 'certSerial', label: '', sensitive: false }, { key: 'publicKey', label: '', sensitive: true }, - { key: 'publicKeyId', label: '', sensitive: false, optional: true }, - { key: 'certSerial', label: '', sensitive: false, optional: true }, + { key: 'publicKeyId', label: '', sensitive: false }, ], stripe: [ { key: 'secretKey', label: '', sensitive: true }, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c0a17d96..8213cb0f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -5432,7 +5432,33 @@ export default { errors: { tooManyPending: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.', cancelRateLimited: 'Too many cancellations. Please try again later.', + // Structured error codes (reason strings from backend ApplicationError) + PAYMENT_DISABLED: 'Payment system is disabled.', + USER_INACTIVE: 'Your account is disabled.', + BALANCE_PAYMENT_DISABLED: 'Balance recharge has been disabled.', + INVALID_AMOUNT: 'Invalid amount.', + INVALID_INPUT: 'Invalid request.', + PLAN_NOT_AVAILABLE: 'Plan not found or no longer available.', + GROUP_NOT_FOUND: 'Subscription group is no longer available.', + GROUP_TYPE_MISMATCH: 'Group is not a subscription type.', + TOO_MANY_PENDING: 'Too many pending orders (max {max}). Please complete or cancel existing orders first.', + DAILY_LIMIT_EXCEEDED: 'Daily recharge limit reached. Remaining: {remaining}.', + PAYMENT_GATEWAY_ERROR: 'Payment method is unavailable.', + NO_AVAILABLE_INSTANCE: 'No payment channel available right now.', + PAYMENT_PROVIDER_MISCONFIGURED: 'Payment provider misconfigured. Please contact an administrator.', + WXPAY_CONFIG_MISSING_KEY: 'WeChat Pay config missing required key: {key}.', + WXPAY_CONFIG_INVALID_KEY_LENGTH: 'WeChat Pay {key} length is invalid (expected {expected} bytes, got {actual}).', + WXPAY_CONFIG_INVALID_KEY: 'WeChat Pay {key} is malformed. Make sure you copied the full PEM content.', PENDING_ORDERS: 'This provider has pending orders. Please wait for them to complete before making changes.', + CANCEL_RATE_LIMITED: 'Too many cancellations. Please try again later.', + NOT_FOUND: 'Order not found.', + FORBIDDEN: 'No permission for this order.', + CONFLICT: 'Order status has changed. Please refresh.', + INVALID_ORDER_TYPE: 'Only balance orders can request a refund.', + INVALID_STATUS: 'The current order status does not allow this operation.', + BALANCE_NOT_ENOUGH: 'Refund amount exceeds balance.', + REFUND_AMOUNT_EXCEEDED: 'Refund amount exceeds the recharge amount.', + REFUND_FAILED: 'Refund failed.', }, stripePay: 'Pay Now', stripeSuccessProcessing: 'Payment successful, processing your order...', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index ba9edd7f..5f936965 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -5620,7 +5620,33 @@ export default { errors: { tooManyPending: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单', cancelRateLimited: '取消订单过于频繁,请稍后再试', + // Structured error codes (reason strings from backend ApplicationError) + PAYMENT_DISABLED: '支付系统已关闭', + USER_INACTIVE: '账号已被禁用', + BALANCE_PAYMENT_DISABLED: '余额充值功能已关闭', + INVALID_AMOUNT: '金额无效', + INVALID_INPUT: '参数有误', + PLAN_NOT_AVAILABLE: '套餐不存在或已下架', + GROUP_NOT_FOUND: '订阅分组不可用', + GROUP_TYPE_MISMATCH: '分组类型不是订阅类型', + TOO_MANY_PENDING: '待支付订单过多(最多 {max} 个),请先完成或取消现有订单', + DAILY_LIMIT_EXCEEDED: '今日充值已达上限,剩余额度 {remaining}', + PAYMENT_GATEWAY_ERROR: '支付方式不可用', + NO_AVAILABLE_INSTANCE: '暂无可用的支付通道', + PAYMENT_PROVIDER_MISCONFIGURED: '支付通道配置错误,请联系管理员', + WXPAY_CONFIG_MISSING_KEY: '微信支付配置缺少必填项:{key}', + WXPAY_CONFIG_INVALID_KEY_LENGTH: '微信支付 {key} 长度错误,应为 {expected} 字节(实际 {actual})', + WXPAY_CONFIG_INVALID_KEY: '微信支付 {key} 格式错误,请确认复制了完整的 PEM 内容', PENDING_ORDERS: '该服务商有未完成的订单,请等待订单完成后再操作', + CANCEL_RATE_LIMITED: '取消订单过于频繁,请稍后再试', + NOT_FOUND: '订单不存在', + FORBIDDEN: '无权限操作此订单', + CONFLICT: '订单状态已变更,请刷新', + INVALID_ORDER_TYPE: '仅余额订单可申请退款', + INVALID_STATUS: '当前订单状态不允许此操作', + BALANCE_NOT_ENOUGH: '退款金额超过余额', + REFUND_AMOUNT_EXCEEDED: '退款金额超过充值金额', + REFUND_FAILED: '退款失败', }, stripePay: '立即支付', stripeSuccessProcessing: '支付成功,正在处理订单...', diff --git a/frontend/src/utils/apiError.ts b/frontend/src/utils/apiError.ts index e1fe0c30..07a17aca 100644 --- a/frontend/src/utils/apiError.ts +++ b/frontend/src/utils/apiError.ts @@ -23,14 +23,96 @@ interface ApiErrorLike { /** * Extract the error code from an API error object. + * + * Prefers the string `reason` (e.g. "PAYMENT_PROVIDER_MISCONFIGURED") over the + * numeric HTTP `code`, because reason is granular enough to drive i18n lookup + * while HTTP code is not. */ export function extractApiErrorCode(err: unknown): string | undefined { if (!err || typeof err !== 'object') return undefined const e = err as ApiErrorLike - const code = e.code ?? e.reason ?? e.response?.data?.code + const code = e.reason ?? e.code ?? e.response?.data?.code return code != null ? String(code) : undefined } +/** + * Extract metadata (interpolation params) from an API error object. + * Backend errors carry `metadata` with template variables that fill i18n placeholders. + */ +export function extractApiErrorMetadata(err: unknown): Record | undefined { + if (!err || typeof err !== 'object') return undefined + const e = err as ApiErrorLike + return e.metadata +} + +type TranslateFn = (key: string, params?: Record) => string +type TranslateWithExistsFn = TranslateFn & { te?: (key: string) => boolean } + +/** + * Translate a value via i18n if a matching key exists, otherwise return the original. + * Example: "certSerial" → t('admin.settings.payment.field_certSerial') → "证书序列号". + */ +function tryTranslate(t: TranslateFn, key: string, fallback: string): string { + const translated = t(key) + if (translated === key) return fallback + const te = (t as TranslateWithExistsFn).te + if (te && !te(key)) return fallback + return translated +} + +/** + * Replace raw config field names in metadata (e.g. "certSerial") with their + * localized UI labels (e.g. "证书序列号"), using the provider-config field i18n namespace. + * Handles both single `key` and `/`-joined `keys` patterns used by wxpay errors. + */ +function localizeMetadata(metadata: Record, t: TranslateFn): Record { + const out: Record = { ...metadata } + if (typeof out.key === 'string') { + out.key = tryTranslate(t, `admin.settings.payment.field_${out.key}`, out.key) + } + if (typeof out.keys === 'string') { + out.keys = out.keys + .split('/') + .map(k => tryTranslate(t, `admin.settings.payment.field_${k}`, k)) + .join(' / ') + } + return out +} + +/** + * Extract a localized error message from an API error by looking up + * `.` in i18n and substituting metadata as placeholders. + * + * Config-field names in metadata (`key` / `keys`) are automatically translated + * to their UI labels before substitution, so error messages read like + * "缺少必填项:证书序列号" instead of "缺少必填项:certSerial". + * + * @param err - The caught error + * @param t - Vue i18n translate function + * @param namespace- i18n key prefix, e.g. "payment.errors" + * @param fallback - Fallback key or plain string if no localized mapping exists + */ +export function extractI18nErrorMessage( + err: unknown, + t: TranslateFn, + namespace: string, + fallback: string, +): string { + const code = extractApiErrorCode(err) + if (code) { + const key = `${namespace}.${code}` + const rawMetadata = extractApiErrorMetadata(err) ?? {} + const metadata = localizeMetadata(rawMetadata, t) + const translated = t(key, metadata) + // Vue i18n returns the key itself when missing; detect that and fall back. + if (translated !== key) return translated + // If the framework exposes `te`, use it to double-check. + const te = (t as TranslateWithExistsFn).te + if (te && te(key)) return translated + } + return extractApiErrorMessage(err, fallback) +} + /** * Extract a displayable error message from an API error. * diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index ee6a4c6d..27cb1f0c 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -2850,7 +2850,7 @@ import ProxySelector from '@/components/common/ProxySelector.vue' import ImageUpload from '@/components/common/ImageUpload.vue' import BackupSettings from '@/views/admin/BackupView.vue' import { useClipboard } from '@/composables/useClipboard' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractApiErrorMessage, extractI18nErrorMessage } from '@/utils/apiError' import { useAppStore } from '@/stores' import { useAdminSettingsStore } from '@/stores/adminSettings' import { @@ -4085,14 +4085,10 @@ const cancelRateLimitModeOptions = computed(() => [ { value: 'fixed', label: t('admin.settings.payment.cancelRateLimitWindowModeFixed') }, ]) -const paymentErrorMap = computed(() => ({ - PENDING_ORDERS: t('payment.errors.PENDING_ORDERS'), -})) - async function loadProviders() { providersLoading.value = true try { const res = await adminAPI.payment.getProviders(); providers.value = res.data || [] } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { providersLoading.value = false } } @@ -4122,7 +4118,7 @@ async function handleSaveProvider(payload: Partial) { // Auto-save settings so provider changes take effect immediately await saveSettings() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { providerSaving.value = false } @@ -4148,7 +4144,7 @@ async function handleToggleField(provider: ProviderInstance, field: 'enabled' | } else { provider.allow_user_refund = newValue } - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } async function handleToggleType(provider: ProviderInstance, type: string) { @@ -4158,7 +4154,7 @@ async function handleToggleType(provider: ProviderInstance, type: string) { try { await adminAPI.payment.updateProvider(provider.id, { supported_types: updated } as any) provider.supported_types = updated - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } function confirmDeleteProvider(provider: ProviderInstance) { @@ -4177,7 +4173,7 @@ async function handleReorderProviders(updates: { id: number; sort_order: number if (p) p.sort_order = u.sort_order } } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) loadProviders() } } @@ -4189,7 +4185,7 @@ async function handleDeleteProvider() { appStore.showSuccess(t('common.deleted')) showDeleteProviderDialog.value = false loadProviders() - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'), paymentErrorMap.value)) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } onMounted(() => { diff --git a/frontend/src/views/admin/orders/AdminOrdersView.vue b/frontend/src/views/admin/orders/AdminOrdersView.vue index 027c8e5e..dd9fa7e6 100644 --- a/frontend/src/views/admin/orders/AdminOrdersView.vue +++ b/frontend/src/views/admin/orders/AdminOrdersView.vue @@ -116,7 +116,7 @@ import { ref, reactive, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { formatOrderDateTime } from '@/components/payment/orderUtils' import type { PaymentOrder } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' @@ -167,7 +167,7 @@ async function loadOrders() { orders.value = res.data.items || [] orderPagination.total = res.data.total || 0 } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { ordersLoading.value = false } } @@ -214,12 +214,12 @@ async function showOrderDetail(order: PaymentOrder) { async function handleCancelOrder(order: PaymentOrder) { try { await adminPaymentAPI.cancelOrder(order.id); appStore.showSuccess(t('payment.admin.orderCancelled')); loadOrders() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } async function handleRetryOrder(order: PaymentOrder) { try { await adminPaymentAPI.retryRecharge(order.id); appStore.showSuccess(t('payment.admin.retrySuccess')); loadOrders() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } function openRefundDialog(order: PaymentOrder) { selectedOrder.value = order; showRefundDialog.value = true } @@ -230,7 +230,7 @@ async function handleRefund(data: { amount: number; reason: string; deduct_balan try { await adminPaymentAPI.refundOrder(selectedOrder.value.id, { amount: data.amount, reason: data.reason, deduct_balance: data.deduct_balance, force: data.force }) appStore.showSuccess(t('payment.admin.refundSuccess')); showRefundDialog.value = false; loadOrders() - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { refundSubmitting.value = false } } diff --git a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue index 06bc9218..5a80db44 100644 --- a/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentDashboardView.vue @@ -72,7 +72,7 @@ import { ref, watch, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import type { DashboardStats } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import LoadingSpinner from '@/components/common/LoadingSpinner.vue' @@ -110,7 +110,7 @@ async function loadDashboard() { const res = await adminPaymentAPI.getDashboard(days.value) stats.value = res.data } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } diff --git a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue index 876b2aa1..c2fc26fe 100644 --- a/frontend/src/views/admin/orders/AdminPaymentPlansView.vue +++ b/frontend/src/views/admin/orders/AdminPaymentPlansView.vue @@ -78,7 +78,7 @@ import { ref, computed, onMounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminPaymentAPI } from '@/api/admin/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import adminAPI from '@/api/admin' import type { SubscriptionPlan } from '@/types/payment' import type { AdminGroup } from '@/types' @@ -150,7 +150,7 @@ async function loadPlans() { : (p.features || []), })) } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { plansLoading.value = false } } @@ -166,7 +166,7 @@ async function toggleForSale(plan: SubscriptionPlan) { await adminPaymentAPI.updatePlan(plan.id, { for_sale: !plan.for_sale }) plan.for_sale = !plan.for_sale } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } @@ -174,7 +174,7 @@ function confirmDeletePlan(plan: SubscriptionPlan) { deletingPlanId.value = plan async function handleDeletePlan() { if (!deletingPlanId.value) return try { await adminPaymentAPI.deletePlan(deletingPlanId.value); appStore.showSuccess(t('common.deleted')); showDeletePlanDialog.value = false; loadPlans() } - catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } } // ==================== Lifecycle ==================== diff --git a/frontend/src/views/user/PaymentQRCodeView.vue b/frontend/src/views/user/PaymentQRCodeView.vue index 0965947a..f844858d 100644 --- a/frontend/src/views/user/PaymentQRCodeView.vue +++ b/frontend/src/views/user/PaymentQRCodeView.vue @@ -39,7 +39,7 @@ import { useRoute, useRouter } from 'vue-router' import AppLayout from '@/components/layout/AppLayout.vue' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { useAppStore } from '@/stores' import QRCode from 'qrcode' import alipayIcon from '@/assets/icons/alipay.svg' @@ -167,7 +167,7 @@ async function handleCancel() { cleanup() router.push('/purchase') } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { cancelling.value = false } diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 3f1401b3..e2885c80 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -271,7 +271,7 @@ import { usePaymentStore } from '@/stores/payment' import { useSubscriptionStore } from '@/stores/subscriptions' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' import type { SubscriptionPlan, CheckoutInfoResponse, OrderType } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' @@ -610,7 +610,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } else if (apiErr.reason === 'CANCEL_RATE_LIMITED') { errorMessage.value = t('payment.errors.cancelRateLimited') } else { - errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + errorMessage.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } appStore.showError(errorMessage.value) } finally { @@ -648,7 +648,7 @@ onMounted(async () => { } } } - } catch (err: unknown) { appStore.showError(extractApiErrorMessage(err, t('common.error'))) } + } catch (err: unknown) { appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } // Fetch active subscriptions (uses cache, non-blocking) subscriptionStore.fetchActiveSubscriptions().catch(() => {}) diff --git a/frontend/src/views/user/StripePaymentView.vue b/frontend/src/views/user/StripePaymentView.vue index 20a4a408..3f73d4d5 100644 --- a/frontend/src/views/user/StripePaymentView.vue +++ b/frontend/src/views/user/StripePaymentView.vue @@ -99,7 +99,7 @@ import { useI18n } from 'vue-i18n' import { useRoute, useRouter } from 'vue-router' import { usePaymentStore } from '@/stores/payment' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' import type { PaymentOrder } from '@/types/payment' import type { Stripe, StripeElements } from '@stripe/stripe-js' @@ -167,7 +167,7 @@ onMounted(async () => { mountPaymentElement(stripe, clientSecret) } } catch (err: unknown) { - initError.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + initError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } finally { loading.value = false } @@ -248,7 +248,7 @@ async function handleGenericPay() { scheduleClose() } } catch (err: unknown) { - stripeError.value = extractApiErrorMessage(err, t('payment.result.failed')) + stripeError.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.result.failed')) } finally { stripeSubmitting.value = false } diff --git a/frontend/src/views/user/StripePopupView.vue b/frontend/src/views/user/StripePopupView.vue index 2704c62d..688ad644 100644 --- a/frontend/src/views/user/StripePopupView.vue +++ b/frontend/src/views/user/StripePopupView.vue @@ -56,7 +56,7 @@ import { computed, ref, onMounted, onUnmounted } from 'vue' import { useI18n } from 'vue-i18n' import { useRoute } from 'vue-router' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import { isMobileDevice } from '@/utils/device' interface StripeWithWechatPay { @@ -143,7 +143,7 @@ async function initStripe(clientSecret: string, publishableKey: string) { } } } catch (err: unknown) { - error.value = extractApiErrorMessage(err, t('payment.stripeLoadFailed')) + error.value = extractI18nErrorMessage(err, t, 'payment.errors', t('payment.stripeLoadFailed')) } } diff --git a/frontend/src/views/user/UserOrdersView.vue b/frontend/src/views/user/UserOrdersView.vue index ea888eb7..c3ed80eb 100644 --- a/frontend/src/views/user/UserOrdersView.vue +++ b/frontend/src/views/user/UserOrdersView.vue @@ -86,7 +86,7 @@ import { useI18n } from 'vue-i18n' import { useRouter } from 'vue-router' import { useAppStore } from '@/stores' import { paymentAPI } from '@/api/payment' -import { extractApiErrorMessage } from '@/utils/apiError' +import { extractI18nErrorMessage } from '@/utils/apiError' import type { PaymentOrder } from '@/types/payment' import AppLayout from '@/components/layout/AppLayout.vue' import Pagination from '@/components/common/Pagination.vue' @@ -128,7 +128,7 @@ async function fetchOrders() { orders.value = res.data.items || [] pagination.total = res.data.total || 0 } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { loading.value = false } @@ -148,7 +148,7 @@ async function confirmCancel() { cancelTargetId.value = null await fetchOrders() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { actionLoading.value = false } @@ -166,7 +166,7 @@ async function confirmRefund() { refundReason.value = '' await fetchOrders() } catch (err: unknown) { - appStore.showError(extractApiErrorMessage(err, t('common.error'))) + appStore.showError(extractI18nErrorMessage(err, t, 'payment.errors', t('common.error'))) } finally { actionLoading.value = false } -- GitLab From 97c9b992cbf8b658b6ef27c27fd0041893b74317 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:27:15 +0800 Subject: [PATCH 049/265] fix: require wechat unionid for oauth identity --- backend/internal/handler/auth_wechat_oauth.go | 6 +- .../handler/auth_wechat_oauth_test.go | 107 +++++++++++------- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 816f60fd..f0755f1f 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -193,11 +193,11 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { unionid := strings.TrimSpace(firstNonEmpty(userInfo.UnionID, tokenResp.UnionID)) openid := strings.TrimSpace(firstNonEmpty(userInfo.OpenID, tokenResp.OpenID)) - providerSubject := firstNonEmpty(unionid, openid) - if providerSubject == "" { - redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_subject", "") + if unionid == "" { + redirectOAuthError(c, frontendCallback, "provider_error", "wechat_missing_unionid", "") return } + providerSubject := unionid username := firstNonEmpty(userInfo.Nickname, wechatFallbackUsername(providerSubject)) email := wechatSyntheticEmail(providerSubject) diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 0d1df1b6..1ff80e1b 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -6,7 +6,6 @@ import ( "bytes" "context" "database/sql" - "encoding/base64" "net/http" "net/http/httptest" "net/url" @@ -122,6 +121,59 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { require.Equal(t, "openid-123", session.UpstreamIdentityClaims["openid"]) } +func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","nickname":"WeChat Nick","headimgurl":"https://cdn.example/avatar.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Contains(t, recorder.Header().Get("Location"), "#error=provider_error") + require.Contains(t, recorder.Header().Get("Location"), "error_message=wechat_missing_unionid") + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + + count, err := client.PendingAuthSession.Query().Count(context.Background()) + require.NoError(t, err) + require.Zero(t, count) +} + func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *testing.T) { testCases := []struct { name string @@ -542,12 +594,7 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl userRepo := &oauthPendingFlowUserRepo{client: client} redeemRepo := repository.NewRedeemCodeRepository(client) - settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - }, - }, &config.Config{ + cfg := &config.Config{ JWT: config.JWTConfig{ Secret: "test-secret", ExpireHour: 1, @@ -558,25 +605,20 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl UserBalance: 0, UserConcurrency: 1, }, - }) + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ + values: map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + }, + }, cfg) authSvc := service.NewAuthService( client, userRepo, redeemRepo, &wechatOAuthRefreshTokenCacheStub{}, - &config.Config{ - JWT: config.JWTConfig{ - Secret: "test-secret", - ExpireHour: 1, - AccessTokenExpireMinutes: 60, - RefreshTokenExpireDays: 7, - }, - Default: config.DefaultConfig{ - UserBalance: 0, - UserConcurrency: 1, - }, - }, + cfg, settingSvc, nil, nil, @@ -588,33 +630,10 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl return &AuthHandler{ authService: authSvc, settingSvc: settingSvc, + cfg: cfg, }, client } -func encodedCookie(name, value string) *http.Cookie { - return &http.Cookie{ - Name: name, - Value: encodeCookieValue(value), - Path: "/", - } -} - -func findCookie(cookies []*http.Cookie, name string) *http.Cookie { - for _, cookie := range cookies { - if cookie.Name == name { - return cookie - } - } - return nil -} - -func decodeCookieValueForTest(t *testing.T, value string) string { - t.Helper() - raw, err := base64.RawURLEncoding.DecodeString(value) - require.NoError(t, err) - return string(raw) -} - func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { t.Helper() -- GitLab From 7c7924e9fa6125f47e2b080496466ab22cec40ad Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:31:19 +0800 Subject: [PATCH 050/265] fix: guard payment fulfillment provider mismatch --- .../internal/service/payment_fulfillment.go | 25 ++++++++ .../service/payment_fulfillment_test.go | 64 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 44818b37..519455f0 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -41,6 +41,19 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo slog.Error("order not found", "orderID", oid) return nil } + instanceProviderKey := "" + if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { + instanceProviderKey = inst.ProviderKey + } + expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey) + if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { + s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ + "expectedProvider": expectedProviderKey, + "actualProvider": pk, + "tradeNo": tradeNo, + }) + return fmt.Errorf("provider mismatch: expected %s, got %s", expectedProviderKey, pk) + } // Skip amount check when paid=0 (e.g. QueryOrder doesn't return amount). // Also skip if paid is NaN/Inf (malformed provider data). if paid > 0 && !math.IsNaN(paid) && !math.IsInf(paid, 0) { @@ -56,6 +69,18 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string { + if key := strings.TrimSpace(instanceProviderKey); key != "" { + return key + } + if registry != nil { + if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { + return key + } + } + return strings.TrimSpace(orderPaymentType) +} + func (s *PaymentService) toPaid(ctx context.Context, o *dbent.PaymentOrder, tradeNo string, paid float64, pk string) error { previousStatus := o.Status now := time.Now() diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 625b0d9f..4cc00301 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -3,12 +3,37 @@ package service import ( + "context" "errors" "testing" + "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" ) +type paymentFulfillmentTestProvider struct { + key string + supportedTypes []payment.PaymentType +} + +func (p paymentFulfillmentTestProvider) Name() string { return p.key } +func (p paymentFulfillmentTestProvider) ProviderKey() string { return p.key } +func (p paymentFulfillmentTestProvider) SupportedTypes() []payment.PaymentType { + return p.supportedTypes +} +func (p paymentFulfillmentTestProvider) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) QueryOrder(ctx context.Context, tradeNo string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) VerifyNotification(ctx context.Context, rawBody string, headers map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -161,3 +186,42 @@ func TestResolveRedeemAction_IsUsedCanUseConsistency(t *testing.T) { assert.True(t, unusedCode.CanUse()) assert.Equal(t, redeemActionRedeem, resolveRedeemAction(unusedCode, nil)) } + +func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay), + ) +} + +func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeEasyPay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, ""), + ) +} + +func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { + t.Parallel() + + assert.Equal(t, + payment.TypeWxpay, + expectedNotificationProviderKey(nil, payment.TypeWxpay, ""), + ) +} -- GitLab From e3f69e02464fa79153aa16db7a457359f6a9a8a3 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:42:01 +0800 Subject: [PATCH 051/265] fix: tighten webhook provider resolution --- backend/internal/service/payment_service.go | 20 --- .../service/payment_webhook_provider.go | 86 +++++++++++ .../service/payment_webhook_provider_test.go | 141 ++++++++++++++++++ 3 files changed, 227 insertions(+), 20 deletions(-) create mode 100644 backend/internal/service/payment_webhook_provider.go create mode 100644 backend/internal/service/payment_webhook_provider_test.go diff --git a/backend/internal/service/payment_service.go b/backend/internal/service/payment_service.go index e897741a..d3175ba6 100644 --- a/backend/internal/service/payment_service.go +++ b/backend/internal/service/payment_service.go @@ -9,7 +9,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -225,25 +224,6 @@ func (s *PaymentService) loadProviders(ctx context.Context) { } } -// GetWebhookProvider returns the provider instance that should verify a webhook. -// It extracts out_trade_no from the raw body, looks up the order to find the -// original provider instance, and creates a provider with that instance's credentials. -// Falls back to the registry provider when the order cannot be found. -func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { - if outTradeNo != "" { - order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) - if err == nil { - p, pErr := s.getOrderProvider(ctx, order) - if pErr == nil { - return p, nil - } - slog.Warn("[Webhook] order provider creation failed, falling back to registry", "outTradeNo", outTradeNo, "error", pErr) - } - } - s.EnsureProviders(ctx) - return s.registry.GetProviderByKey(providerKey) -} - // --- Helpers --- func psIsRefundStatus(s string) bool { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go new file mode 100644 index 00000000..a877db2b --- /dev/null +++ b/backend/internal/service/payment_webhook_provider.go @@ -0,0 +1,86 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/Wei-Shaw/sub2api/internal/payment/provider" +) + +// GetWebhookProvider returns the provider instance that should verify a webhook. +// It resolves the original provider instance from the order whenever possible and +// only falls back to a registry provider for legacy/single-instance scenarios. +func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, outTradeNo string) (payment.Provider, error) { + if outTradeNo != "" { + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(outTradeNo)).Only(ctx) + if err == nil { + if psHasPinnedProviderInstance(order) { + return s.getPinnedOrderProvider(ctx, order) + } + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) + } + } + + if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { + return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) + } + + s.EnsureProviders(ctx) + return s.registry.GetProviderByKey(providerKey) +} + +func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst == nil { + return nil, fmt.Errorf("order %d provider instance is missing", o.ID) + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create pinned provider: %w", err) + } + return prov, nil +} + +func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { + providerKey = strings.TrimSpace(providerKey) + if providerKey == "" || s == nil || s.entClient == nil { + return false + } + + count, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.ProviderKeyEQ(providerKey), + paymentproviderinstance.EnabledEQ(true), + ). + Count(ctx) + if err != nil { + slog.Warn("payment webhook fallback instance count failed", "provider", providerKey, "error", err) + return false + } + return count <= 1 +} + +func psHasPinnedProviderInstance(order *dbent.PaymentOrder) bool { + return order != nil && order.ProviderInstanceID != nil && strings.TrimSpace(*order.ProviderInstanceID) != "" +} diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go new file mode 100644 index 00000000..85c296de --- /dev/null +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -0,0 +1,141 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +type webhookProviderTestDouble struct { + key string + types []payment.PaymentType +} + +func (p webhookProviderTestDouble) Name() string { return p.key } +func (p webhookProviderTestDouble) ProviderKey() string { return p.key } +func (p webhookProviderTestDouble) SupportedTypes() []payment.PaymentType { return p.types } +func (p webhookProviderTestDouble) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) QueryOrder(context.Context, string) (*payment.QueryOrderResponse, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) VerifyNotification(context.Context, string, map[string]string) (*payment.PaymentNotification, error) { + panic("unexpected call") +} +func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest) (*payment.RefundResponse, error) { + panic("unexpected call") +} + +func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-b"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + registry: payment.NewRegistry(), + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "") + require.Error(t, err) + require.Contains(t, err.Error(), "ambiguous") +} + +func TestGetWebhookProviderAllowsSingleInstanceRegistryFallback(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeStripe, + types: []payment.PaymentType{payment.TypeStripe}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + prov, err := svc.GetWebhookProvider(ctx, payment.TypeStripe, "") + require.NoError(t, err) + require.Equal(t, payment.TypeStripe, prov.ProviderKey()) +} + +func TestGetWebhookProviderRejectsRegistryFallbackForPinnedOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("webhook@example.com"). + SetPasswordHash("hash"). + SetUsername("webhook"). + Save(ctx) + require.NoError(t, err) + + pinnedInstanceID := "999" + _, err = client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("TEST-RECHARGE"). + SetOutTradeNo("sub2_test_pinned_order"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(pinnedInstanceID). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + registry.Register(webhookProviderTestDouble{ + key: payment.TypeWxpay, + types: []payment.PaymentType{payment.TypeWxpay}, + }) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + _, err = svc.GetWebhookProvider(ctx, payment.TypeWxpay, "sub2_test_pinned_order") + require.Error(t, err) + require.Contains(t, err.Error(), "provider instance") +} -- GitLab From c0b24aefba926c61d749af11cd01cd6f96bb65fe Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:47:14 +0800 Subject: [PATCH 052/265] feat: snapshot payment provider keys on orders --- backend/ent/migrate/schema.go | 15 ++-- backend/ent/mutation.go | 75 ++++++++++++++++- backend/ent/paymentorder.go | 16 +++- backend/ent/paymentorder/paymentorder.go | 10 +++ backend/ent/paymentorder/where.go | 80 ++++++++++++++++++ backend/ent/paymentorder_create.go | 83 +++++++++++++++++++ backend/ent/paymentorder_update.go | 62 ++++++++++++++ backend/ent/runtime/runtime.go | 20 +++-- backend/ent/schema/payment_order.go | 4 + .../internal/service/payment_fulfillment.go | 7 +- .../service/payment_fulfillment_test.go | 21 ++++- backend/internal/service/payment_order.go | 8 +- .../service/payment_order_lifecycle.go | 13 ++- ...dd_payment_order_provider_key_snapshot.sql | 10 +++ 14 files changed, 400 insertions(+), 24 deletions(-) create mode 100644 backend/migrations/112_add_payment_order_provider_key_snapshot.sql diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index bf41e73b..230ea060 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -654,6 +654,7 @@ var ( {Name: "subscription_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "subscription_days", Type: field.TypeInt, Nullable: true}, {Name: "provider_instance_id", Type: field.TypeString, Nullable: true, Size: 64}, + {Name: "provider_key", Type: field.TypeString, Nullable: true, Size: 30}, {Name: "status", Type: field.TypeString, Size: 30, Default: "PENDING"}, {Name: "refund_amount", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,2)"}}, {Name: "refund_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, @@ -682,7 +683,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "payment_orders_users_payment_orders", - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[38]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -696,32 +697,32 @@ var ( { Name: "paymentorder_user_id", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[37]}, + Columns: []*schema.Column{PaymentOrdersColumns[38]}, }, { Name: "paymentorder_status", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[19]}, + Columns: []*schema.Column{PaymentOrdersColumns[20]}, }, { Name: "paymentorder_expires_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[27]}, + Columns: []*schema.Column{PaymentOrdersColumns[28]}, }, { Name: "paymentorder_created_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[35]}, + Columns: []*schema.Column{PaymentOrdersColumns[36]}, }, { Name: "paymentorder_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[29]}, }, { Name: "paymentorder_payment_type_paid_at", Unique: false, - Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[28]}, + Columns: []*schema.Column{PaymentOrdersColumns[9], PaymentOrdersColumns[29]}, }, { Name: "paymentorder_order_type", diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 12905c9a..5227015c 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15385,6 +15385,7 @@ type PaymentOrderMutation struct { subscription_days *int addsubscription_days *int provider_instance_id *string + provider_key *string status *string refund_amount *float64 addrefund_amount *float64 @@ -16421,6 +16422,55 @@ func (m *PaymentOrderMutation) ResetProviderInstanceID() { delete(m.clearedFields, paymentorder.FieldProviderInstanceID) } +// SetProviderKey sets the "provider_key" field. +func (m *PaymentOrderMutation) SetProviderKey(s string) { + m.provider_key = &s +} + +// ProviderKey returns the value of the "provider_key" field in the mutation. +func (m *PaymentOrderMutation) ProviderKey() (r string, exists bool) { + v := m.provider_key + if v == nil { + return + } + return *v, true +} + +// OldProviderKey returns the old "provider_key" field's value of the PaymentOrder entity. +// If the PaymentOrder object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentOrderMutation) OldProviderKey(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldProviderKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldProviderKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldProviderKey: %w", err) + } + return oldValue.ProviderKey, nil +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (m *PaymentOrderMutation) ClearProviderKey() { + m.provider_key = nil + m.clearedFields[paymentorder.FieldProviderKey] = struct{}{} +} + +// ProviderKeyCleared returns if the "provider_key" field was cleared in this mutation. +func (m *PaymentOrderMutation) ProviderKeyCleared() bool { + _, ok := m.clearedFields[paymentorder.FieldProviderKey] + return ok +} + +// ResetProviderKey resets all changes to the "provider_key" field. +func (m *PaymentOrderMutation) ResetProviderKey() { + m.provider_key = nil + delete(m.clearedFields, paymentorder.FieldProviderKey) +} + // SetStatus sets the "status" field. func (m *PaymentOrderMutation) SetStatus(s string) { m.status = &s @@ -17280,7 +17330,7 @@ func (m *PaymentOrderMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentOrderMutation) Fields() []string { - fields := make([]string, 0, 37) + fields := make([]string, 0, 38) if m.user != nil { fields = append(fields, paymentorder.FieldUserID) } @@ -17338,6 +17388,9 @@ func (m *PaymentOrderMutation) Fields() []string { if m.provider_instance_id != nil { fields = append(fields, paymentorder.FieldProviderInstanceID) } + if m.provider_key != nil { + fields = append(fields, paymentorder.FieldProviderKey) + } if m.status != nil { fields = append(fields, paymentorder.FieldStatus) } @@ -17438,6 +17491,8 @@ func (m *PaymentOrderMutation) Field(name string) (ent.Value, bool) { return m.SubscriptionDays() case paymentorder.FieldProviderInstanceID: return m.ProviderInstanceID() + case paymentorder.FieldProviderKey: + return m.ProviderKey() case paymentorder.FieldStatus: return m.Status() case paymentorder.FieldRefundAmount: @@ -17521,6 +17576,8 @@ func (m *PaymentOrderMutation) OldField(ctx context.Context, name string) (ent.V return m.OldSubscriptionDays(ctx) case paymentorder.FieldProviderInstanceID: return m.OldProviderInstanceID(ctx) + case paymentorder.FieldProviderKey: + return m.OldProviderKey(ctx) case paymentorder.FieldStatus: return m.OldStatus(ctx) case paymentorder.FieldRefundAmount: @@ -17699,6 +17756,13 @@ func (m *PaymentOrderMutation) SetField(name string, value ent.Value) error { } m.SetProviderInstanceID(v) return nil + case paymentorder.FieldProviderKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetProviderKey(v) + return nil case paymentorder.FieldStatus: v, ok := value.(string) if !ok { @@ -17966,6 +18030,9 @@ func (m *PaymentOrderMutation) ClearedFields() []string { if m.FieldCleared(paymentorder.FieldProviderInstanceID) { fields = append(fields, paymentorder.FieldProviderInstanceID) } + if m.FieldCleared(paymentorder.FieldProviderKey) { + fields = append(fields, paymentorder.FieldProviderKey) + } if m.FieldCleared(paymentorder.FieldRefundReason) { fields = append(fields, paymentorder.FieldRefundReason) } @@ -18034,6 +18101,9 @@ func (m *PaymentOrderMutation) ClearField(name string) error { case paymentorder.FieldProviderInstanceID: m.ClearProviderInstanceID() return nil + case paymentorder.FieldProviderKey: + m.ClearProviderKey() + return nil case paymentorder.FieldRefundReason: m.ClearRefundReason() return nil @@ -18129,6 +18199,9 @@ func (m *PaymentOrderMutation) ResetField(name string) error { case paymentorder.FieldProviderInstanceID: m.ResetProviderInstanceID() return nil + case paymentorder.FieldProviderKey: + m.ResetProviderKey() + return nil case paymentorder.FieldStatus: m.ResetStatus() return nil diff --git a/backend/ent/paymentorder.go b/backend/ent/paymentorder.go index 6ea3e709..a58823ee 100644 --- a/backend/ent/paymentorder.go +++ b/backend/ent/paymentorder.go @@ -56,6 +56,8 @@ type PaymentOrder struct { SubscriptionDays *int `json:"subscription_days,omitempty"` // ProviderInstanceID holds the value of the "provider_instance_id" field. ProviderInstanceID *string `json:"provider_instance_id,omitempty"` + // ProviderKey holds the value of the "provider_key" field. + ProviderKey *string `json:"provider_key,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` // RefundAmount holds the value of the "refund_amount" field. @@ -129,7 +131,7 @@ func (*PaymentOrder) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case paymentorder.FieldID, paymentorder.FieldUserID, paymentorder.FieldPlanID, paymentorder.FieldSubscriptionGroupID, paymentorder.FieldSubscriptionDays: values[i] = new(sql.NullInt64) - case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: + case paymentorder.FieldUserEmail, paymentorder.FieldUserName, paymentorder.FieldUserNotes, paymentorder.FieldRechargeCode, paymentorder.FieldOutTradeNo, paymentorder.FieldPaymentType, paymentorder.FieldPaymentTradeNo, paymentorder.FieldPayURL, paymentorder.FieldQrCode, paymentorder.FieldQrCodeImg, paymentorder.FieldOrderType, paymentorder.FieldProviderInstanceID, paymentorder.FieldProviderKey, paymentorder.FieldStatus, paymentorder.FieldRefundReason, paymentorder.FieldRefundRequestReason, paymentorder.FieldRefundRequestedBy, paymentorder.FieldFailedReason, paymentorder.FieldClientIP, paymentorder.FieldSrcHost, paymentorder.FieldSrcURL: values[i] = new(sql.NullString) case paymentorder.FieldRefundAt, paymentorder.FieldRefundRequestedAt, paymentorder.FieldExpiresAt, paymentorder.FieldPaidAt, paymentorder.FieldCompletedAt, paymentorder.FieldFailedAt, paymentorder.FieldCreatedAt, paymentorder.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -276,6 +278,13 @@ func (_m *PaymentOrder) assignValues(columns []string, values []any) error { _m.ProviderInstanceID = new(string) *_m.ProviderInstanceID = value.String } + case paymentorder.FieldProviderKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field provider_key", values[i]) + } else if value.Valid { + _m.ProviderKey = new(string) + *_m.ProviderKey = value.String + } case paymentorder.FieldStatus: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) @@ -508,6 +517,11 @@ func (_m *PaymentOrder) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ProviderKey; v != nil { + builder.WriteString("provider_key=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") diff --git a/backend/ent/paymentorder/paymentorder.go b/backend/ent/paymentorder/paymentorder.go index 4467b2b6..af9b1422 100644 --- a/backend/ent/paymentorder/paymentorder.go +++ b/backend/ent/paymentorder/paymentorder.go @@ -52,6 +52,8 @@ const ( FieldSubscriptionDays = "subscription_days" // FieldProviderInstanceID holds the string denoting the provider_instance_id field in the database. FieldProviderInstanceID = "provider_instance_id" + // FieldProviderKey holds the string denoting the provider_key field in the database. + FieldProviderKey = "provider_key" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" // FieldRefundAmount holds the string denoting the refund_amount field in the database. @@ -123,6 +125,7 @@ var Columns = []string{ FieldSubscriptionGroupID, FieldSubscriptionDays, FieldProviderInstanceID, + FieldProviderKey, FieldStatus, FieldRefundAmount, FieldRefundReason, @@ -176,6 +179,8 @@ var ( OrderTypeValidator func(string) error // ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. ProviderInstanceIDValidator func(string) error + // ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + ProviderKeyValidator func(string) error // DefaultStatus holds the default value on creation for the "status" field. DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. @@ -301,6 +306,11 @@ func ByProviderInstanceID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldProviderInstanceID, opts...).ToFunc() } +// ByProviderKey orders the results by the provider_key field. +func ByProviderKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldProviderKey, opts...).ToFunc() +} + // ByStatus orders the results by the status field. func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() diff --git a/backend/ent/paymentorder/where.go b/backend/ent/paymentorder/where.go index 78520fac..0f6b74a0 100644 --- a/backend/ent/paymentorder/where.go +++ b/backend/ent/paymentorder/where.go @@ -150,6 +150,11 @@ func ProviderInstanceID(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldProviderInstanceID, v)) } +// ProviderKey applies equality check predicate on the "provider_key" field. It's identical to ProviderKeyEQ. +func ProviderKey(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. func Status(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) @@ -1360,6 +1365,81 @@ func ProviderInstanceIDContainsFold(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderInstanceID, v)) } +// ProviderKeyEQ applies the EQ predicate on the "provider_key" field. +func ProviderKeyEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEQ(FieldProviderKey, v)) +} + +// ProviderKeyNEQ applies the NEQ predicate on the "provider_key" field. +func ProviderKeyNEQ(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNEQ(FieldProviderKey, v)) +} + +// ProviderKeyIn applies the In predicate on the "provider_key" field. +func ProviderKeyIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIn(FieldProviderKey, vs...)) +} + +// ProviderKeyNotIn applies the NotIn predicate on the "provider_key" field. +func ProviderKeyNotIn(vs ...string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotIn(FieldProviderKey, vs...)) +} + +// ProviderKeyGT applies the GT predicate on the "provider_key" field. +func ProviderKeyGT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGT(FieldProviderKey, v)) +} + +// ProviderKeyGTE applies the GTE predicate on the "provider_key" field. +func ProviderKeyGTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldGTE(FieldProviderKey, v)) +} + +// ProviderKeyLT applies the LT predicate on the "provider_key" field. +func ProviderKeyLT(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLT(FieldProviderKey, v)) +} + +// ProviderKeyLTE applies the LTE predicate on the "provider_key" field. +func ProviderKeyLTE(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldLTE(FieldProviderKey, v)) +} + +// ProviderKeyContains applies the Contains predicate on the "provider_key" field. +func ProviderKeyContains(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContains(FieldProviderKey, v)) +} + +// ProviderKeyHasPrefix applies the HasPrefix predicate on the "provider_key" field. +func ProviderKeyHasPrefix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasPrefix(FieldProviderKey, v)) +} + +// ProviderKeyHasSuffix applies the HasSuffix predicate on the "provider_key" field. +func ProviderKeyHasSuffix(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldHasSuffix(FieldProviderKey, v)) +} + +// ProviderKeyIsNil applies the IsNil predicate on the "provider_key" field. +func ProviderKeyIsNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldIsNull(FieldProviderKey)) +} + +// ProviderKeyNotNil applies the NotNil predicate on the "provider_key" field. +func ProviderKeyNotNil() predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldNotNull(FieldProviderKey)) +} + +// ProviderKeyEqualFold applies the EqualFold predicate on the "provider_key" field. +func ProviderKeyEqualFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldEqualFold(FieldProviderKey, v)) +} + +// ProviderKeyContainsFold applies the ContainsFold predicate on the "provider_key" field. +func ProviderKeyContainsFold(v string) predicate.PaymentOrder { + return predicate.PaymentOrder(sql.FieldContainsFold(FieldProviderKey, v)) +} + // StatusEQ applies the EQ predicate on the "status" field. func StatusEQ(v string) predicate.PaymentOrder { return predicate.PaymentOrder(sql.FieldEQ(FieldStatus, v)) diff --git a/backend/ent/paymentorder_create.go b/backend/ent/paymentorder_create.go index 03098339..497ba52c 100644 --- a/backend/ent/paymentorder_create.go +++ b/backend/ent/paymentorder_create.go @@ -225,6 +225,20 @@ func (_c *PaymentOrderCreate) SetNillableProviderInstanceID(v *string) *PaymentO return _c } +// SetProviderKey sets the "provider_key" field. +func (_c *PaymentOrderCreate) SetProviderKey(v string) *PaymentOrderCreate { + _c.mutation.SetProviderKey(v) + return _c +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_c *PaymentOrderCreate) SetNillableProviderKey(v *string) *PaymentOrderCreate { + if v != nil { + _c.SetProviderKey(*v) + } + return _c +} + // SetStatus sets the "status" field. func (_c *PaymentOrderCreate) SetStatus(v string) *PaymentOrderCreate { _c.mutation.SetStatus(v) @@ -602,6 +616,11 @@ func (_c *PaymentOrderCreate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _c.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if _, ok := _c.mutation.Status(); !ok { return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "PaymentOrder.status"`)} } @@ -748,6 +767,10 @@ func (_c *PaymentOrderCreate) createSpec() (*PaymentOrder, *sqlgraph.CreateSpec) _spec.SetField(paymentorder.FieldProviderInstanceID, field.TypeString, value) _node.ProviderInstanceID = &value } + if value, ok := _c.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + _node.ProviderKey = &value + } if value, ok := _c.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) _node.Status = value @@ -1201,6 +1224,24 @@ func (u *PaymentOrderUpsert) ClearProviderInstanceID() *PaymentOrderUpsert { return u } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsert) SetProviderKey(v string) *PaymentOrderUpsert { + u.Set(paymentorder.FieldProviderKey, v) + return u +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsert) UpdateProviderKey() *PaymentOrderUpsert { + u.SetExcluded(paymentorder.FieldProviderKey) + return u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsert) ClearProviderKey() *PaymentOrderUpsert { + u.SetNull(paymentorder.FieldProviderKey) + return u +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsert) SetStatus(v string) *PaymentOrderUpsert { u.Set(paymentorder.FieldStatus, v) @@ -1880,6 +1921,27 @@ func (u *PaymentOrderUpsertOne) ClearProviderInstanceID() *PaymentOrderUpsertOne }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertOne) SetProviderKey(v string) *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertOne) UpdateProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertOne) ClearProviderKey() *PaymentOrderUpsertOne { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertOne) SetStatus(v string) *PaymentOrderUpsertOne { return u.Update(func(s *PaymentOrderUpsert) { @@ -2770,6 +2832,27 @@ func (u *PaymentOrderUpsertBulk) ClearProviderInstanceID() *PaymentOrderUpsertBu }) } +// SetProviderKey sets the "provider_key" field. +func (u *PaymentOrderUpsertBulk) SetProviderKey(v string) *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.SetProviderKey(v) + }) +} + +// UpdateProviderKey sets the "provider_key" field to the value that was provided on create. +func (u *PaymentOrderUpsertBulk) UpdateProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.UpdateProviderKey() + }) +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (u *PaymentOrderUpsertBulk) ClearProviderKey() *PaymentOrderUpsertBulk { + return u.Update(func(s *PaymentOrderUpsert) { + s.ClearProviderKey() + }) +} + // SetStatus sets the "status" field. func (u *PaymentOrderUpsertBulk) SetStatus(v string) *PaymentOrderUpsertBulk { return u.Update(func(s *PaymentOrderUpsert) { diff --git a/backend/ent/paymentorder_update.go b/backend/ent/paymentorder_update.go index 5978fc29..9a901415 100644 --- a/backend/ent/paymentorder_update.go +++ b/backend/ent/paymentorder_update.go @@ -385,6 +385,26 @@ func (_u *PaymentOrderUpdate) ClearProviderInstanceID() *PaymentOrderUpdate { return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdate) SetProviderKey(v string) *PaymentOrderUpdate { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdate) SetNillableProviderKey(v *string) *PaymentOrderUpdate { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdate) ClearProviderKey() *PaymentOrderUpdate { + _u.mutation.ClearProviderKey() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdate) SetStatus(v string) *PaymentOrderUpdate { _u.mutation.SetStatus(v) @@ -776,6 +796,11 @@ func (_u *PaymentOrderUpdate) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -910,6 +935,12 @@ func (_u *PaymentOrderUpdate) sqlSave(ctx context.Context) (_node int, err error if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } @@ -1399,6 +1430,26 @@ func (_u *PaymentOrderUpdateOne) ClearProviderInstanceID() *PaymentOrderUpdateOn return _u } +// SetProviderKey sets the "provider_key" field. +func (_u *PaymentOrderUpdateOne) SetProviderKey(v string) *PaymentOrderUpdateOne { + _u.mutation.SetProviderKey(v) + return _u +} + +// SetNillableProviderKey sets the "provider_key" field if the given value is not nil. +func (_u *PaymentOrderUpdateOne) SetNillableProviderKey(v *string) *PaymentOrderUpdateOne { + if v != nil { + _u.SetProviderKey(*v) + } + return _u +} + +// ClearProviderKey clears the value of the "provider_key" field. +func (_u *PaymentOrderUpdateOne) ClearProviderKey() *PaymentOrderUpdateOne { + _u.mutation.ClearProviderKey() + return _u +} + // SetStatus sets the "status" field. func (_u *PaymentOrderUpdateOne) SetStatus(v string) *PaymentOrderUpdateOne { _u.mutation.SetStatus(v) @@ -1803,6 +1854,11 @@ func (_u *PaymentOrderUpdateOne) check() error { return &ValidationError{Name: "provider_instance_id", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_instance_id": %w`, err)} } } + if v, ok := _u.mutation.ProviderKey(); ok { + if err := paymentorder.ProviderKeyValidator(v); err != nil { + return &ValidationError{Name: "provider_key", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.provider_key": %w`, err)} + } + } if v, ok := _u.mutation.Status(); ok { if err := paymentorder.StatusValidator(v); err != nil { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "PaymentOrder.status": %w`, err)} @@ -1954,6 +2010,12 @@ func (_u *PaymentOrderUpdateOne) sqlSave(ctx context.Context) (_node *PaymentOrd if _u.mutation.ProviderInstanceIDCleared() { _spec.ClearField(paymentorder.FieldProviderInstanceID, field.TypeString) } + if value, ok := _u.mutation.ProviderKey(); ok { + _spec.SetField(paymentorder.FieldProviderKey, field.TypeString, value) + } + if _u.mutation.ProviderKeyCleared() { + _spec.ClearField(paymentorder.FieldProviderKey, field.TypeString) + } if value, ok := _u.mutation.Status(); ok { _spec.SetField(paymentorder.FieldStatus, field.TypeString, value) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 268e9ddb..b7118ac9 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -723,38 +723,42 @@ func init() { paymentorderDescProviderInstanceID := paymentorderFields[18].Descriptor() // paymentorder.ProviderInstanceIDValidator is a validator for the "provider_instance_id" field. It is called by the builders before save. paymentorder.ProviderInstanceIDValidator = paymentorderDescProviderInstanceID.Validators[0].(func(string) error) + // paymentorderDescProviderKey is the schema descriptor for provider_key field. + paymentorderDescProviderKey := paymentorderFields[19].Descriptor() + // paymentorder.ProviderKeyValidator is a validator for the "provider_key" field. It is called by the builders before save. + paymentorder.ProviderKeyValidator = paymentorderDescProviderKey.Validators[0].(func(string) error) // paymentorderDescStatus is the schema descriptor for status field. - paymentorderDescStatus := paymentorderFields[19].Descriptor() + paymentorderDescStatus := paymentorderFields[20].Descriptor() // paymentorder.DefaultStatus holds the default value on creation for the status field. paymentorder.DefaultStatus = paymentorderDescStatus.Default.(string) // paymentorder.StatusValidator is a validator for the "status" field. It is called by the builders before save. paymentorder.StatusValidator = paymentorderDescStatus.Validators[0].(func(string) error) // paymentorderDescRefundAmount is the schema descriptor for refund_amount field. - paymentorderDescRefundAmount := paymentorderFields[20].Descriptor() + paymentorderDescRefundAmount := paymentorderFields[21].Descriptor() // paymentorder.DefaultRefundAmount holds the default value on creation for the refund_amount field. paymentorder.DefaultRefundAmount = paymentorderDescRefundAmount.Default.(float64) // paymentorderDescForceRefund is the schema descriptor for force_refund field. - paymentorderDescForceRefund := paymentorderFields[23].Descriptor() + paymentorderDescForceRefund := paymentorderFields[24].Descriptor() // paymentorder.DefaultForceRefund holds the default value on creation for the force_refund field. paymentorder.DefaultForceRefund = paymentorderDescForceRefund.Default.(bool) // paymentorderDescRefundRequestedBy is the schema descriptor for refund_requested_by field. - paymentorderDescRefundRequestedBy := paymentorderFields[26].Descriptor() + paymentorderDescRefundRequestedBy := paymentorderFields[27].Descriptor() // paymentorder.RefundRequestedByValidator is a validator for the "refund_requested_by" field. It is called by the builders before save. paymentorder.RefundRequestedByValidator = paymentorderDescRefundRequestedBy.Validators[0].(func(string) error) // paymentorderDescClientIP is the schema descriptor for client_ip field. - paymentorderDescClientIP := paymentorderFields[32].Descriptor() + paymentorderDescClientIP := paymentorderFields[33].Descriptor() // paymentorder.ClientIPValidator is a validator for the "client_ip" field. It is called by the builders before save. paymentorder.ClientIPValidator = paymentorderDescClientIP.Validators[0].(func(string) error) // paymentorderDescSrcHost is the schema descriptor for src_host field. - paymentorderDescSrcHost := paymentorderFields[33].Descriptor() + paymentorderDescSrcHost := paymentorderFields[34].Descriptor() // paymentorder.SrcHostValidator is a validator for the "src_host" field. It is called by the builders before save. paymentorder.SrcHostValidator = paymentorderDescSrcHost.Validators[0].(func(string) error) // paymentorderDescCreatedAt is the schema descriptor for created_at field. - paymentorderDescCreatedAt := paymentorderFields[35].Descriptor() + paymentorderDescCreatedAt := paymentorderFields[36].Descriptor() // paymentorder.DefaultCreatedAt holds the default value on creation for the created_at field. paymentorder.DefaultCreatedAt = paymentorderDescCreatedAt.Default.(func() time.Time) // paymentorderDescUpdatedAt is the schema descriptor for updated_at field. - paymentorderDescUpdatedAt := paymentorderFields[36].Descriptor() + paymentorderDescUpdatedAt := paymentorderFields[37].Descriptor() // paymentorder.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentorder.DefaultUpdatedAt = paymentorderDescUpdatedAt.Default.(func() time.Time) // paymentorder.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/ent/schema/payment_order.go b/backend/ent/schema/payment_order.go index a9576d2a..64378de1 100644 --- a/backend/ent/schema/payment_order.go +++ b/backend/ent/schema/payment_order.go @@ -91,6 +91,10 @@ func (PaymentOrder) Fields() []ent.Field { Optional(). Nillable(). MaxLen(64), + field.String("provider_key"). + Optional(). + Nillable(). + MaxLen(30), // 状态 field.String("status"). diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 519455f0..83bac21d 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -45,7 +45,7 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo if inst, instErr := s.getOrderProviderInstance(ctx, o); instErr == nil && inst != nil { instanceProviderKey = inst.ProviderKey } - expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, instanceProviderKey) + expectedProviderKey := expectedNotificationProviderKey(s.registry, o.PaymentType, psStringValue(o.ProviderKey), instanceProviderKey) if expectedProviderKey != "" && strings.TrimSpace(pk) != "" && !strings.EqualFold(expectedProviderKey, strings.TrimSpace(pk)) { s.writeAuditLog(ctx, o.ID, "PAYMENT_PROVIDER_MISMATCH", pk, map[string]any{ "expectedProvider": expectedProviderKey, @@ -69,10 +69,13 @@ func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo return s.toPaid(ctx, o, tradeNo, paid, pk) } -func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, instanceProviderKey string) string { +func expectedNotificationProviderKey(registry *payment.Registry, orderPaymentType string, orderProviderKey string, instanceProviderKey string) string { if key := strings.TrimSpace(instanceProviderKey); key != "" { return key } + if key := strings.TrimSpace(orderProviderKey); key != "" { + return key + } if registry != nil { if key := strings.TrimSpace(registry.GetProviderKey(payment.PaymentType(orderPaymentType))); key != "" { return key diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index 4cc00301..712129b0 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -198,7 +198,7 @@ func TestExpectedNotificationProviderKeyPrefersOrderInstanceProvider(t *testing. assert.Equal(t, payment.TypeEasyPay, - expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay), + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", payment.TypeEasyPay), ) } @@ -213,7 +213,7 @@ func TestExpectedNotificationProviderKeyUsesRegistryMappingForLegacyOrders(t *te assert.Equal(t, payment.TypeEasyPay, - expectedNotificationProviderKey(registry, payment.TypeAlipay, ""), + expectedNotificationProviderKey(registry, payment.TypeAlipay, "", ""), ) } @@ -222,6 +222,21 @@ func TestExpectedNotificationProviderKeyFallsBackToPaymentType(t *testing.T) { assert.Equal(t, payment.TypeWxpay, - expectedNotificationProviderKey(nil, payment.TypeWxpay, ""), + expectedNotificationProviderKey(nil, payment.TypeWxpay, "", ""), + ) +} + +func TestExpectedNotificationProviderKeyPrefersOrderSnapshotProviderKey(t *testing.T) { + t.Parallel() + + registry := payment.NewRegistry() + registry.Register(paymentFulfillmentTestProvider{ + key: payment.TypeAlipay, + supportedTypes: []payment.PaymentType{payment.TypeAlipay}, + }) + + assert.Equal(t, + payment.TypeEasyPay, + expectedNotificationProviderKey(registry, payment.TypeAlipay, payment.TypeEasyPay, ""), ) } diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index fa256be7..4b9b1872 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -251,7 +251,13 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) } - _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) + _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID). + SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)). + SetNillablePayURL(psNilIfEmpty(pr.PayURL)). + SetNillableQrCode(psNilIfEmpty(pr.QRCode)). + SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)). + SetNillableProviderKey(psNilIfEmpty(sel.ProviderKey)). + Save(ctx) if err != nil { return nil, fmt.Errorf("update order with payment details: %w", err) } diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index 80147180..f804eb8b 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -241,7 +242,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO if err == nil { cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey == "" { + providerKey = s.registry.GetProviderKey(o.PaymentType) + } if providerKey == "" { providerKey = o.PaymentType } @@ -255,3 +259,10 @@ func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentO s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } + +func psStringValue(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/backend/migrations/112_add_payment_order_provider_key_snapshot.sql b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql new file mode 100644 index 00000000..7ec19ae3 --- /dev/null +++ b/backend/migrations/112_add_payment_order_provider_key_snapshot.sql @@ -0,0 +1,10 @@ +ALTER TABLE payment_orders ADD COLUMN provider_key VARCHAR(30); + +UPDATE payment_orders +SET provider_key = ( + SELECT provider_key + FROM payment_provider_instances + WHERE CAST(id AS TEXT) = payment_orders.provider_instance_id +) +WHERE provider_key IS NULL + AND provider_instance_id IS NOT NULL; -- GitLab From 9bebf1c1a611e37c68a127ffee07a670460a6a09 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:53:46 +0800 Subject: [PATCH 053/265] feat: resolve payment results by resume token --- backend/internal/handler/payment_handler.go | 29 +++++ backend/internal/server/routes/payment.go | 1 + .../internal/service/payment_resume_lookup.go | 35 ++++++ .../service/payment_resume_lookup_test.go | 118 ++++++++++++++++++ frontend/src/api/payment.ts | 5 + frontend/src/views/user/PaymentResultView.vue | 12 +- .../user/__tests__/PaymentResultView.spec.ts | 26 ++++ 7 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 backend/internal/service/payment_resume_lookup.go create mode 100644 backend/internal/service/payment_resume_lookup_test.go diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index d5c3d7c8..6440cdfd 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -357,6 +357,10 @@ type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` } +type ResolveOrderByResumeTokenRequest struct { + ResumeToken string `json:"resume_token" binding:"required"` +} + // VerifyOrder actively queries the upstream payment provider to check // if payment was made, and processes it if so. // POST /api/v1/payment/orders/verify @@ -417,6 +421,31 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { }) } +// ResolveOrderPublicByResumeToken resolves a payment order from a signed resume token. +// POST /api/v1/payment/public/orders/resolve +func (h *PaymentHandler) ResolveOrderPublicByResumeToken(c *gin.Context) { + var req ResolveOrderByResumeTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + order, err := h.paymentService.GetPublicOrderByResumeToken(c.Request.Context(), req.ResumeToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, PublicOrderResult{ + ID: order.ID, + OutTradeNo: order.OutTradeNo, + Amount: order.Amount, + PayAmount: order.PayAmount, + PaymentType: order.PaymentType, + OrderType: order.OrderType, + Status: order.Status, + }) +} + // requireAuth extracts the authenticated subject from the context. // Returns the subject and true on success; on failure it writes an Unauthorized response and returns false. func requireAuth(c *gin.Context) (middleware2.AuthSubject, bool) { diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 23bd58ad..dff14a70 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -49,6 +49,7 @@ func RegisterPaymentRoutes( public := v1.Group("/payment/public") { public.POST("/orders/verify", paymentHandler.VerifyOrderPublic) + public.POST("/orders/resolve", paymentHandler.ResolveOrderPublicByResumeToken) } // --- Webhook endpoints (no auth) --- diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go new file mode 100644 index 00000000..493ca325 --- /dev/null +++ b/backend/internal/service/payment_resume_lookup.go @@ -0,0 +1,35 @@ +package service + +import ( + "context" + "fmt" + "strings" + + dbent "github.com/Wei-Shaw/sub2api/ent" +) + +func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token string) (*dbent.PaymentOrder, error) { + claims, err := s.paymentResume().ParseToken(strings.TrimSpace(token)) + if err != nil { + return nil, err + } + + order, err := s.entClient.PaymentOrder.Get(ctx, claims.OrderID) + if err != nil { + return nil, fmt.Errorf("get order by resume token: %w", err) + } + if claims.UserID > 0 && order.UserID != claims.UserID { + return nil, fmt.Errorf("resume token user mismatch") + } + if claims.ProviderInstanceID != "" && strings.TrimSpace(psStringValue(order.ProviderInstanceID)) != claims.ProviderInstanceID { + return nil, fmt.Errorf("resume token provider instance mismatch") + } + if claims.ProviderKey != "" && strings.TrimSpace(psStringValue(order.ProviderKey)) != claims.ProviderKey { + return nil, fmt.Errorf("resume token provider key mismatch") + } + if claims.PaymentType != "" && strings.TrimSpace(order.PaymentType) != claims.PaymentType { + return nil, fmt.Errorf("resume token payment type mismatch") + } + + return order, nil +} diff --git a/backend/internal/service/payment_resume_lookup_test.go b/backend/internal/service/payment_resume_lookup_test.go new file mode 100644 index 00000000..d411398e --- /dev/null +++ b/backend/internal/service/payment_resume_lookup_test.go @@ -0,0 +1,118 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/payment" + "github.com/stretchr/testify/require" +) + +func TestGetPublicOrderByResumeTokenReturnsMatchingOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-user"). + Save(ctx) + require.NoError(t, err) + + instanceID := "12" + providerKey := payment.TypeEasyPay + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-ORDER"). + SetOutTradeNo("sub2_resume_lookup"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-1"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID(instanceID). + SetProviderKey(providerKey). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: instanceID, + ProviderKey: providerKey, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + got, err := svc.GetPublicOrderByResumeToken(ctx, token) + require.NoError(t, err) + require.Equal(t, order.ID, got.ID) +} + +func TestGetPublicOrderByResumeTokenRejectsSnapshotMismatch(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + user, err := client.User.Create(). + SetEmail("resume-mismatch@example.com"). + SetPasswordHash("hash"). + SetUsername("resume-mismatch-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("RESUME-MISMATCH"). + SetOutTradeNo("sub2_resume_lookup_mismatch"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-2"). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + SetProviderInstanceID("12"). + SetProviderKey(payment.TypeEasyPay). + Save(ctx) + require.NoError(t, err) + + resumeSvc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef")) + token, err := resumeSvc.CreateToken(ResumeTokenClaims{ + OrderID: order.ID, + UserID: user.ID, + ProviderInstanceID: "99", + ProviderKey: payment.TypeEasyPay, + PaymentType: payment.TypeAlipay, + CanonicalReturnURL: "https://app.example.com/payment/result", + }) + require.NoError(t, err) + + svc := &PaymentService{ + entClient: client, + resumeService: resumeSvc, + } + + _, err = svc.GetPublicOrderByResumeToken(ctx, token) + require.Error(t, err) + require.Contains(t, err.Error(), "resume token") +} diff --git a/frontend/src/api/payment.ts b/frontend/src/api/payment.ts index 5cedb107..91b16866 100644 --- a/frontend/src/api/payment.ts +++ b/frontend/src/api/payment.ts @@ -72,6 +72,11 @@ export const paymentAPI = { return apiClient.post('/payment/public/orders/verify', { out_trade_no: outTradeNo }) }, + /** Resolve an order from a signed resume token without auth */ + resolveOrderPublicByResumeToken(resumeToken: string) { + return apiClient.post('/payment/public/orders/resolve', { resume_token: resumeToken }) + }, + /** Request a refund for a completed order */ requestRefund(id: number, data: { reason: string }) { return apiClient.post(`/payment/orders/${id}/refund-request`, data) diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index e1db3ce2..9687d1c7 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -150,7 +150,17 @@ onMounted(async () => { } } - if (orderId) { + if (!order.value && !orderId && resumeToken) { + try { + const result = await paymentAPI.resolveOrderPublicByResumeToken(resumeToken) + order.value = result.data + orderId = result.data.id + } catch (_err: unknown) { + // Resume token recovery failed, continue to legacy fallback paths. + } + } + + if (!order.value && orderId) { try { order.value = await paymentStore.pollOrderStatus(orderId) } catch (_err: unknown) { diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index b06217ab..d23a60d9 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -9,6 +9,7 @@ const routerPush = vi.hoisted(() => vi.fn()) const pollOrderStatus = vi.hoisted(() => vi.fn()) const verifyOrderPublic = vi.hoisted(() => vi.fn()) const verifyOrder = vi.hoisted(() => vi.fn()) +const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn()) vi.mock('vue-router', async () => { const actual = await vi.importActual('vue-router') @@ -39,6 +40,7 @@ vi.mock('@/api/payment', () => ({ paymentAPI: { verifyOrderPublic, verifyOrder, + resolveOrderPublicByResumeToken, }, })) @@ -67,6 +69,7 @@ describe('PaymentResultView', () => { pollOrderStatus.mockReset() verifyOrderPublic.mockReset() verifyOrder.mockReset() + resolveOrderPublicByResumeToken.mockReset() window.localStorage.clear() }) @@ -129,4 +132,27 @@ describe('PaymentResultView', () => { expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123') expect(wrapper.text()).toContain('payment.result.success') }) + + it('resolves order by resume token when local recovery snapshot is missing', async () => { + routeState.query = { + resume_token: 'resume-77', + } + resolveOrderPublicByResumeToken.mockResolvedValue({ + data: orderFactory('PAID'), + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(resolveOrderPublicByResumeToken).toHaveBeenCalledWith('resume-77') + expect(wrapper.text()).toContain('payment.result.success') + expect(verifyOrderPublic).not.toHaveBeenCalled() + }) }) -- GitLab From 32059ae9d5fcf7024fadee26cf1e9c32dcc6da4b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 20:58:19 +0800 Subject: [PATCH 054/265] fix: backfill email identities on successful login --- .../internal/service/auth_oauth_email_flow.go | 6 +++ backend/internal/service/auth_service.go | 8 ++++ .../auth_service_identity_sync_test.go | 37 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index ca3403d4..4ab4e245 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -147,5 +147,11 @@ func (s *AuthService) ValidatePasswordCredentials(ctx context.Context, email, pa // RecordSuccessfulLogin updates last-login activity after a non-standard login // flow finishes with a real session. func (s *AuthService) RecordSuccessfulLogin(ctx context.Context, userID int64) { + if s != nil && s.userRepo != nil && userID > 0 { + user, err := s.userRepo.GetByID(ctx, userID) + if err == nil { + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) + } + } s.touchUserLogin(ctx, userID) } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 40753139..dda6df04 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -430,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string if !user.IsActive() { return "", nil, ErrUserNotActive } + s.backfillEmailIdentityOnSuccessfulLogin(ctx, user) s.touchUserLogin(ctx, user.ID) // 生成JWT token @@ -802,6 +803,13 @@ func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) { } } +func (s *AuthService) backfillEmailIdentityOnSuccessfulLogin(ctx context.Context, user *User) { + if s == nil || user == nil || user.ID <= 0 { + return + } + s.ensureEmailAuthIdentity(ctx, user) +} + func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) { if s == nil || s.entClient == nil || user == nil || user.ID <= 0 { return diff --git a/backend/internal/service/auth_service_identity_sync_test.go b/backend/internal/service/auth_service_identity_sync_test.go index 5bd2b25d..fcb4813b 100644 --- a/backend/internal/service/auth_service_identity_sync_test.go +++ b/backend/internal/service/auth_service_identity_sync_test.go @@ -150,4 +150,41 @@ func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) { require.NotNil(t, storedUser.LastActiveAt) require.True(t, storedUser.LastLoginAt.After(old)) require.True(t, storedUser.LastActiveAt.After(old)) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("login@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) +} + +func TestAuthServiceRecordSuccessfulLoginBackfillsEmailIdentity(t *testing.T) { + svc, repo, client := newAuthServiceWithEnt(t) + ctx := context.Background() + + user := &service.User{ + Email: "record@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 1, + Concurrency: 1, + } + require.NoError(t, user.SetPassword("password")) + require.NoError(t, repo.Create(ctx, user)) + + svc.RecordSuccessfulLogin(ctx, user.ID) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("email"), + authidentity.ProviderKeyEQ("email"), + authidentity.ProviderSubjectEQ("record@example.com"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) } -- GitLab From bdcd3d87e530fc1bbbec4888a51ae79c1cb0e298 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:09:38 +0800 Subject: [PATCH 055/265] fix: resolve unique legacy payment providers --- .../service/payment_order_lifecycle.go | 46 +++++--- backend/internal/service/payment_refund.go | 79 ++++++++++++- .../service/payment_webhook_provider.go | 22 ++-- .../service/payment_webhook_provider_test.go | 109 ++++++++++++++++++ 4 files changed, 220 insertions(+), 36 deletions(-) diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index f804eb8b..24d8b7a2 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" "strconv" - "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -237,29 +236,38 @@ func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) // getOrderProvider creates a provider using the order's original instance config. // Falls back to registry lookup if instance ID is missing (legacy orders). func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) - if providerKey == "" { - providerKey = s.registry.GetProviderKey(o.PaymentType) - } - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } + inst, err := s.getOrderProviderInstance(ctx, o) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) } s.EnsureProviders(ctx) return s.registry.GetProvider(o.PaymentType) } +func (s *PaymentService) createProviderFromInstance(ctx context.Context, inst *dbent.PaymentProviderInstance) (payment.Provider, error) { + if inst == nil { + return nil, fmt.Errorf("payment provider instance is missing") + } + + cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) + if err != nil { + return nil, fmt.Errorf("load provider instance config: %w", err) + } + if inst.PaymentMode != "" { + cfg["paymentMode"] = inst.PaymentMode + } + + instID := strconv.FormatInt(int64(inst.ID), 10) + prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) + if err != nil { + return nil, fmt.Errorf("create provider from instance: %w", err) + } + return prov, nil +} + func psStringValue(value *string) string { if value == nil { return "" diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index c5bda763..01eecff8 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" + "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) @@ -19,18 +20,90 @@ import ( // --- Refund Flow --- // getOrderProviderInstance looks up the provider instance that processed this order. -// Returns nil, nil for legacy orders without provider_instance_id. +// For legacy orders without provider_instance_id, it resolves only when the +// enabled instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { - if o.ProviderInstanceID == nil || *o.ProviderInstanceID == "" { + if s == nil || s.entClient == nil || o == nil { return nil, nil } - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) + + instIDStr := strings.TrimSpace(psStringValue(o.ProviderInstanceID)) + if instIDStr == "" { + return s.resolveUniqueLegacyOrderProviderInstance(ctx, o) + } + + instID, err := strconv.ParseInt(instIDStr, 10, 64) if err != nil { return nil, nil } return s.entClient.PaymentProviderInstance.Get(ctx, instID) } +func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) + if providerKey != "" { + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where( + paymentproviderinstance.EnabledEQ(true), + paymentproviderinstance.ProviderKeyEQ(providerKey), + ). + All(ctx) + if err != nil { + return nil, err + } + if len(instances) == 1 { + return instances[0], nil + } + return nil, nil + } + + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) + if paymentType == "" { + return nil, nil + } + + instances, err := s.entClient.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.EnabledEQ(true)). + All(ctx) + if err != nil { + return nil, err + } + + var matched []*dbent.PaymentProviderInstance + for _, inst := range instances { + if psLegacyOrderMatchesInstance(paymentType, inst) { + matched = append(matched, inst) + } + } + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { + if inst == nil { + return false + } + + baseType := payment.GetBasePaymentType(strings.TrimSpace(orderPaymentType)) + instanceProviderKey := strings.TrimSpace(inst.ProviderKey) + if baseType == "" { + return false + } + + if baseType == payment.TypeStripe { + return instanceProviderKey == payment.TypeStripe + } + if instanceProviderKey == payment.TypeStripe { + return false + } + if instanceProviderKey == baseType { + return true + } + return payment.InstanceSupportsType(inst.SupportedTypes, baseType) +} + func (s *PaymentService) RequestRefund(ctx context.Context, oid, uid int64, reason string) error { o, err := s.validateRefundRequest(ctx, oid, uid) if err != nil { diff --git a/backend/internal/service/payment_webhook_provider.go b/backend/internal/service/payment_webhook_provider.go index a877db2b..289d63ed 100644 --- a/backend/internal/service/payment_webhook_provider.go +++ b/backend/internal/service/payment_webhook_provider.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "log/slog" - "strconv" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentproviderinstance" "github.com/Wei-Shaw/sub2api/internal/payment" - "github.com/Wei-Shaw/sub2api/internal/payment/provider" ) // GetWebhookProvider returns the provider instance that should verify a webhook. @@ -24,6 +22,13 @@ func (s *PaymentService) GetWebhookProvider(ctx context.Context, providerKey, ou if psHasPinnedProviderInstance(order) { return s.getPinnedOrderProvider(ctx, order) } + inst, err := s.getOrderProviderInstance(ctx, order) + if err != nil { + return nil, fmt.Errorf("load order provider instance: %w", err) + } + if inst != nil { + return s.createProviderFromInstance(ctx, inst) + } if !s.webhookRegistryFallbackAllowed(ctx, providerKey) { return nil, fmt.Errorf("webhook provider fallback is ambiguous for %s", providerKey) } @@ -48,18 +53,7 @@ func (s *PaymentService) getPinnedOrderProvider(ctx context.Context, o *dbent.Pa if inst == nil { return nil, fmt.Errorf("order %d provider instance is missing", o.ID) } - - instID := strconv.FormatInt(int64(inst.ID), 10) - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, int64(inst.ID)) - if err != nil { - return nil, fmt.Errorf("load provider instance config: %w", err) - } - - prov, err := provider.CreateProvider(inst.ProviderKey, instID, cfg) - if err != nil { - return nil, fmt.Errorf("create pinned provider: %w", err) - } - return prov, nil + return s.createProviderFromInstance(ctx, inst) } func (s *PaymentService) webhookRegistryFallbackAllowed(ctx context.Context, providerKey string) bool { diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 85c296de..33e4186d 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -4,13 +4,17 @@ package service import ( "context" + "encoding/json" "testing" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/require" ) +const webhookProviderTestEncryptionKey = "0123456789abcdef0123456789abcdef" + type webhookProviderTestDouble struct { key string types []payment.PaymentType @@ -32,6 +36,111 @@ func (p webhookProviderTestDouble) Refund(context.Context, payment.RefundRequest panic("unexpected call") } +func encryptWebhookProviderConfig(t *testing.T, config map[string]string) string { + t.Helper() + + data, err := json.Marshal(config) + require.NoError(t, err) + + encrypted, err := payment.Encrypt(string(data), []byte(webhookProviderTestEncryptionKey)) + require.NoError(t, err) + return encrypted +} + +func newWebhookProviderTestLoadBalancer(client *dbent.Client) payment.LoadBalancer { + return payment.NewDefaultLoadBalancer(client, []byte(webhookProviderTestEncryptionKey)) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyProviderKey(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-a"). + SetConfig(encryptWebhookProviderConfig(t, map[string]string{"secretKey": "sk_test_legacy_provider_key"})). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceResolvesUniqueLegacyPaymentType(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + inst, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpayDirect, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, inst.ID, got.ID) +} + +func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeEasyPay). + SetName("easypay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-a"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeWxpay, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) -- GitLab From 5adefb466b9c7a1cd02e2c9f1c0221c08d38e5bf Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:24:33 +0800 Subject: [PATCH 056/265] fix: finalize oauth identity bindings --- .../handler/auth_oauth_pending_flow.go | 152 +++++++++++++++++- .../handler/auth_oauth_pending_flow_test.go | 11 +- backend/internal/handler/auth_wechat_oauth.go | 96 +++++++++++ .../handler/auth_wechat_oauth_test.go | 124 ++++++++++++++ 4 files changed, 376 insertions(+), 7 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 6d6564e8..21ed2bc6 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -10,6 +10,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -309,6 +310,14 @@ func cloneOAuthMetadata(values map[string]any) map[string]any { return cloned } +func mergeOAuthMetadata(base map[string]any, overlay map[string]any) map[string]any { + merged := cloneOAuthMetadata(base) + for key, value := range overlay { + merged[key] = value + } + return merged +} + func normalizeAdoptedOAuthDisplayName(value string) string { value = strings.TrimSpace(value) if len([]rune(value)) > 100 { @@ -558,6 +567,10 @@ func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { } func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + if session != nil && strings.EqualFold(strings.TrimSpace(session.ProviderType), "wechat") { + return ensurePendingWeChatOAuthIdentityForUser(ctx, tx, session, userID) + } + client := tx.Client() identity, err := client.AuthIdentity.Query(). Where( @@ -588,14 +601,149 @@ func ensurePendingOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, sessio return create.Save(ctx) } +func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, session *dbent.PendingAuthSession, userID int64) (*dbent.AuthIdentity, error) { + client := tx.Client() + providerType := strings.TrimSpace(session.ProviderType) + providerKey := strings.TrimSpace(session.ProviderKey) + providerSubject := strings.TrimSpace(session.ProviderSubject) + channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) + channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) + channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) + metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + var legacyOpenIDIdentity *dbent.AuthIdentity + if channelSubject != "" && channelSubject != providerSubject { + legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(channelSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + + switch { + case identity != nil: + update := client.AuthIdentity.UpdateOneID(identity.ID). + SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + case legacyOpenIDIdentity != nil: + update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderSubject(providerSubject). + SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) + if issuer := oauthIdentityIssuer(session); issuer != nil { + update = update.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, err + } + default: + create := client.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetMetadata(metadata) + if issuer := oauthIdentityIssuer(session); issuer != nil { + create = create.SetIssuer(strings.TrimSpace(*issuer)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, err + } + } + + if channel == "" || channelAppID == "" || channelSubject == "" { + return identity, nil + } + + channelRecord, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(channelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, err + } + if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + + channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) + if channelRecord == nil { + if _, err := client.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channel). + SetChannelAppID(channelAppID). + SetChannelSubject(channelSubject). + SetMetadata(channelMetadata). + Save(ctx); err != nil { + return nil, err + } + return identity, nil + } + + _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + SetIdentityID(identity.ID). + SetMetadata(channelMetadata). + Save(ctx) + if err != nil { + return nil, err + } + return identity, nil +} + +func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { + if channel == nil { + return map[string]any{} + } + return cloneOAuthMetadata(channel.Metadata) +} + func shouldBindPendingOAuthIdentity(session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision) bool { if session == nil || decision == nil { return false } - if strings.EqualFold(strings.TrimSpace(session.Intent), "bind_current_user") { + switch strings.ToLower(strings.TrimSpace(session.Intent)) { + case "bind_current_user", "login", "adopt_existing_user_by_email": return true + default: + return decision.AdoptDisplayName || decision.AdoptAvatar } - return decision.AdoptDisplayName || decision.AdoptAvatar } func applyPendingOAuthBinding( diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index ae506e52..89accd60 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -372,7 +372,7 @@ func TestExchangePendingOAuthCompletionBindCurrentUserOwnershipConflict(t *testi require.Nil(t, storedSession.ConsumedAt) } -func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *testing.T) { +func TestExchangePendingOAuthCompletionLoginFalseFalseBindsIdentityWithoutAdoption(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -420,21 +420,22 @@ func TestExchangePendingOAuthCompletionLoginFalseFalseDoesNotBindIdentity(t *tes require.Equal(t, http.StatusOK, recorder.Code) - identityCount, err := client.AuthIdentity.Query(). + identity, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("linuxdo"), authidentity.ProviderKeyEQ("linuxdo"), authidentity.ProviderSubjectEQ("login-false-123"), ). - Count(ctx) + Only(ctx) require.NoError(t, err) - require.Zero(t, identityCount) + require.Equal(t, userEntity.ID, identity.UserID) decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Only(ctx) require.NoError(t, err) - require.Nil(t, decision.IdentityID) + require.NotNil(t, decision.IdentityID) + require.Equal(t, identity.ID, *decision.IdentityID) require.False(t, decision.AdoptDisplayName) require.False(t, decision.AdoptAvatar) diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index f0755f1f..b6d47670 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -242,7 +242,18 @@ func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) return } + if existingIdentityUser == nil { + existingIdentityUser, err = h.findWeChatUserByLegacyOpenID(c.Request.Context(), identityRef, cfg, openid) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + } if existingIdentityUser != nil { + if err := h.ensureWeChatRuntimeIdentityBinding(c.Request.Context(), existingIdentityUser.ID, identityRef, upstreamClaims); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), existingIdentityUser.Email, username, "") if err != nil { redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -511,6 +522,91 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } +func (h *AuthHandler) findWeChatUserByLegacyOpenID( + ctx context.Context, + identity service.PendingAuthIdentityKey, + cfg wechatOAuthConfig, + openid string, +) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + openid = strings.TrimSpace(openid) + channel := strings.TrimSpace(cfg.mode) + channelAppID := strings.TrimSpace(cfg.appID) + if openid != "" && channel != "" && channelAppID != "" { + record, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ChannelEQ(channel), + authidentitychannel.ChannelAppIDEQ(channelAppID), + authidentitychannel.ChannelSubjectEQ(openid), + ). + WithIdentity(func(q *dbent.AuthIdentityQuery) { + q.WithUser() + }). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { + return record.Edges.Identity.Edges.User, nil + } + } + + if openid == "" { + return nil, nil + } + + record, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), + authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderSubjectEQ(openid), + ). + WithUser(). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + return record.Edges.User, nil +} + +func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( + ctx context.Context, + userID int64, + identity service.PendingAuthIdentityKey, + upstreamClaims map[string]any, +) error { + client := h.entClient() + if client == nil { + return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + tx, err := client.Tx(ctx) + if err != nil { + return infraerrors.InternalServer("AUTH_IDENTITY_BIND_FAILED", "failed to begin wechat identity repair transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + _, err = ensurePendingOAuthIdentityForUser(dbent.NewTxContext(ctx, tx), tx, &dbent.PendingAuthSession{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + UpstreamIdentityClaims: cloneOAuthMetadata(upstreamClaims), + }, userID) + if err != nil { + return err + } + return tx.Commit() +} + func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, c *gin.Context) (wechatOAuthConfig, error) { mode, err := resolveWeChatOAuthMode(rawMode, c) if err != nil { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index 1ff80e1b..dd022fb9 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -15,6 +15,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/enttest" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" "github.com/Wei-Shaw/sub2api/ent/pendingauthsession" @@ -563,6 +564,19 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.Equal(t, "WeChat Display", identity.Metadata["display_name"]) require.Equal(t, "https://cdn.example/wechat.png", identity.Metadata["avatar_url"]) + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "union-456", channel.Metadata["unionid"]) + decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(pendingSession.ID)). Only(ctx) @@ -579,6 +593,116 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing require.NotNil(t, consumed.ConsumedAt) } +func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy WeChat","headimgurl":"https://cdn.example/legacy.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthProviderKey). + SetProviderSubject("openid-123"). + SetMetadata(map[string]any{"openid": "openid-123"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + openIDIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("openid-123"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, openIDIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() -- GitLab From f65429145e57c7bb61844fc978a3ef92e53710ef Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:31:05 +0800 Subject: [PATCH 057/265] fix: route legacy linuxdo users to account binding --- .../internal/handler/auth_linuxdo_oauth.go | 59 ++++++++++++++ .../handler/auth_linuxdo_oauth_test.go | 76 +++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index c3ec3804..a0760a3b 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -15,6 +15,8 @@ import ( "time" "unicode/utf8" + dbent "github.com/Wei-Shaw/sub2api/ent" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" @@ -237,6 +239,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") return } + compatEmail := strings.TrimSpace(email) // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 @@ -255,6 +258,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { "suggested_display_name": displayName, "suggested_avatar_url": avatarURL, } + if compatEmail != "" && !strings.EqualFold(strings.TrimSpace(compatEmail), strings.TrimSpace(email)) { + upstreamClaims["compat_email"] = compatEmail + } if intent == oauthIntentBindCurrentUser { targetUserID, err := h.readOAuthBindUserIDFromCookie(c, linuxDoOAuthBindUserCookieName) if err != nil { @@ -314,6 +320,33 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { return } + compatEmailUser, err := h.findLinuxDoCompatEmailUser(c.Request.Context(), compatEmail) + if err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + if compatEmailUser != nil { + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: "adopt_existing_user_by_email", + Identity: identityKey, + TargetUserID: &compatEmailUser.ID, + ResolvedEmail: compatEmailUser.Email, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{ + "redirect": redirectTo, + "step": "bind_login_required", + "email": compatEmailUser.Email, + }, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { if err := h.createOAuthEmailRequiredPendingSession(c, identityKey, redirectTo, browserSessionKey, upstreamClaims); err != nil { redirectOAuthError(c, frontendCallback, "session_error", "failed to continue oauth login", "") @@ -372,6 +405,32 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { redirectToFrontendCallback(c, frontendCallback) } +func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntity, err := client.User.Query(). + Where(dbuser.EmailEqualFold(email)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, nil + } + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + return userEntity, nil +} + type completeLinuxDoOAuthRequest struct { InvitationCode string `json:"invitation_code" binding:"required"` AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go index 765779b5..fb57e570 100644 --- a/backend/internal/handler/auth_linuxdo_oauth_test.go +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -300,6 +300,82 @@ func TestLinuxDoOAuthCallbackCreatesLoginPendingSessionForExistingUser(t *testin require.Nil(t, completion["error"]) } +func TestLinuxDoOAuthCallbackCreatesBindPendingSessionForCompatEmailUser(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"linuxdo-access","token_type":"Bearer","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"321","email":"legacy@example.com","username":"linuxdo_user","name":"LinuxDo Display","avatar_url":"https://cdn.example/linuxdo.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + + handler, client := newLinuxDoOAuthHandlerAndClient(t, false, config.LinuxDoConnectConfig{ + Enabled: true, + ClientID: "linuxdo-client", + ClientSecret: "linuxdo-secret", + AuthorizeURL: upstream.URL + "/authorize", + TokenURL: upstream.URL + "/token", + UserInfoURL: upstream.URL + "/userinfo", + Scopes: "read", + RedirectURL: "https://api.example.com/api/v1/auth/oauth/linuxdo/callback", + FrontendRedirectURL: "/auth/linuxdo/callback", + TokenAuthMethod: "client_secret_post", + UsePKCE: true, + }) + defer client.Close() + + ctx := context.Background() + existingUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/linuxdo/callback?code=code-compat&state=state-compat", nil) + req.AddCookie(encodedCookie(linuxDoOAuthStateCookieName, "state-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(linuxDoOAuthVerifierCookie, "verifier-compat")) + req.AddCookie(encodedCookie(linuxDoOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-compat")) + c.Request = req + + handler.LinuxDoOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/linuxdo/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.Equal(t, "adopt_existing_user_by_email", session.Intent) + require.NotNil(t, session.TargetUserID) + require.Equal(t, existingUser.ID, *session.TargetUserID) + require.Equal(t, existingUser.Email, session.ResolvedEmail) + require.Equal(t, "legacy@example.com", session.UpstreamIdentityClaims["compat_email"]) + + completion := session.LocalFlowState[oauthCompletionResponseKey].(map[string]any) + require.Equal(t, "/dashboard", completion["redirect"]) + require.Equal(t, "bind_login_required", completion["step"]) + require.Equal(t, existingUser.Email, completion["email"]) + _, hasAccessToken := completion["access_token"] + require.False(t, hasAccessToken) +} + func TestLinuxDoOAuthCallbackCreatesInvitationPendingSessionWhenSignupRequiresInvite(t *testing.T) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { -- GitLab From 422f60a14513b8c9df33bdf8209d82fb6888d7b9 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:42:35 +0800 Subject: [PATCH 058/265] fix: normalize legacy wechat auth identity keys --- .../handler/auth_oauth_pending_flow.go | 109 ++++++++-- backend/internal/handler/auth_wechat_oauth.go | 122 ++++++++--- .../handler/auth_wechat_oauth_test.go | 192 ++++++++++++++++++ ...3_normalize_legacy_wechat_provider_key.sql | 89 ++++++++ 4 files changed, 464 insertions(+), 48 deletions(-) create mode 100644 backend/migrations/113_normalize_legacy_wechat_provider_key.sql diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 21ed2bc6..fd35e4e5 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -606,39 +606,42 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, providerType := strings.TrimSpace(session.ProviderType) providerKey := strings.TrimSpace(session.ProviderKey) providerSubject := strings.TrimSpace(session.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(providerKey) channel := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel")) channelAppID := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_app_id")) channelSubject := strings.TrimSpace(pendingSessionStringValue(session.UpstreamIdentityClaims, "channel_subject")) metadata := cloneOAuthMetadata(session.UpstreamIdentityClaims) - identity, err := client.AuthIdentity.Query(). + identityRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(providerSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if identity != nil && identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + identity, hasCanonicalKey, err := chooseWeChatIdentityForUser(identityRecords, userID, providerKey) + if err != nil { + return nil, err } var legacyOpenIDIdentity *dbent.AuthIdentity if channelSubject != "" && channelSubject != providerSubject { - legacyOpenIDIdentity, err = client.AuthIdentity.Query(). + legacyOpenIDRecords, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(channelSubject), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if legacyOpenIDIdentity != nil && legacyOpenIDIdentity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + legacyOpenIDIdentity, _, err = chooseWeChatIdentityForUser(legacyOpenIDRecords, userID, providerKey) + if err != nil { + return nil, err } } @@ -646,6 +649,9 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, case identity != nil: update := client.AuthIdentity.UpdateOneID(identity.ID). SetMetadata(mergeOAuthMetadata(identity.Metadata, metadata)) + if !strings.EqualFold(strings.TrimSpace(identity.ProviderKey), providerKey) && !hasCanonicalKey { + update = update.SetProviderKey(providerKey) + } if issuer := oauthIdentityIssuer(session); issuer != nil { update = update.SetIssuer(strings.TrimSpace(*issuer)) } @@ -655,6 +661,7 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, } case legacyOpenIDIdentity != nil: update := client.AuthIdentity.UpdateOneID(legacyOpenIDIdentity.ID). + SetProviderKey(providerKey). SetProviderSubject(providerSubject). SetMetadata(mergeOAuthMetadata(legacyOpenIDIdentity.Metadata, metadata)) if issuer := oauthIdentityIssuer(session); issuer != nil { @@ -684,21 +691,22 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - channelRecord, err := client.AuthIdentityChannel.Query(). + channelRecords, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, err } - if channelRecord != nil && channelRecord.Edges.Identity != nil && channelRecord.Edges.Identity.UserID != userID { - return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + channelRecord, hasCanonicalChannelKey, err := chooseWeChatChannelForUser(channelRecords, userID, providerKey) + if err != nil { + return nil, err } channelMetadata := mergeOAuthMetadata(channelRecordMetadata(channelRecord), metadata) @@ -717,16 +725,75 @@ func ensurePendingWeChatOAuthIdentityForUser(ctx context.Context, tx *dbent.Tx, return identity, nil } - _, err = client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). + updateChannel := client.AuthIdentityChannel.UpdateOneID(channelRecord.ID). SetIdentityID(identity.ID). - SetMetadata(channelMetadata). - Save(ctx) + SetMetadata(channelMetadata) + if !strings.EqualFold(strings.TrimSpace(channelRecord.ProviderKey), providerKey) && !hasCanonicalChannelKey { + updateChannel = updateChannel.SetProviderKey(providerKey) + } + _, err = updateChannel.Save(ctx) if err != nil { return nil, err } return identity, nil } +func chooseWeChatIdentityForUser(records []*dbent.AuthIdentity, userID int64, preferredProviderKey string) (*dbent.AuthIdentity, bool, error) { + var preferred *dbent.AuthIdentity + var fallback *dbent.AuthIdentity + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + +func chooseWeChatChannelForUser(records []*dbent.AuthIdentityChannel, userID int64, preferredProviderKey string) (*dbent.AuthIdentityChannel, bool, error) { + var preferred *dbent.AuthIdentityChannel + var fallback *dbent.AuthIdentityChannel + hasCanonicalKey := false + for _, record := range records { + if record == nil { + continue + } + if record.Edges.Identity != nil && record.Edges.Identity.UserID != userID { + return nil, false, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if strings.EqualFold(strings.TrimSpace(record.ProviderKey), preferredProviderKey) { + hasCanonicalKey = true + if preferred == nil { + preferred = record + } + continue + } + if fallback == nil { + fallback = record + } + } + if preferred != nil { + return preferred, hasCanonicalKey, nil + } + return fallback, hasCanonicalKey, nil +} + func channelRecordMetadata(channel *dbent.AuthIdentityChannel) map[string]any { if channel == nil { return map[string]any{} diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index b6d47670..95993dfc 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -34,6 +34,7 @@ const ( wechatOAuthDefaultRedirectTo = "/dashboard" wechatOAuthDefaultFrontendCB = "/auth/wechat/callback" wechatOAuthProviderKey = "wechat-main" + wechatOAuthLegacyProviderKey = "wechat" wechatOAuthIntentLogin = "login" wechatOAuthIntentBind = "bind_current_user" @@ -483,18 +484,20 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } - identity, err := client.AuthIdentity.Query(). + identities, err := client.AuthIdentity.Query(). Where( authidentity.ProviderTypeEQ("wechat"), - authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentity.ProviderSubjectEQ(strings.TrimSpace(providerSubject)), ). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_LOOKUP_FAILED", "failed to inspect wechat identity ownership").WithCause(err) } - if identity != nil && identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + for _, identity := range identities { + if identity != nil && identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } } channelSubject = strings.TrimSpace(channelSubject) @@ -503,21 +506,23 @@ func (h *AuthHandler) ensureWeChatBindOwnership( return nil } - channel, err := client.AuthIdentityChannel.Query(). + channels, err := client.AuthIdentityChannel.Query(). Where( authidentitychannel.ProviderTypeEQ("wechat"), - authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ProviderKeyIn(wechatCompatibleProviderKeys(wechatOAuthProviderKey)...), authidentitychannel.ChannelEQ(strings.TrimSpace(cfg.mode)), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(channelSubject), ). WithIdentity(). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return infraerrors.InternalServer("WECHAT_BIND_CHANNEL_LOOKUP_FAILED", "failed to inspect wechat identity channel ownership").WithCause(err) } - if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { - return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + for _, channel := range channels { + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } } return nil } @@ -533,14 +538,34 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") } + providerType := strings.TrimSpace(identity.ProviderType) + providerSubject := strings.TrimSpace(identity.ProviderSubject) + providerKeys := wechatCompatibleProviderKeys(identity.ProviderKey) + if providerSubject != "" { + records, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + WithUser(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if user, err := singleWeChatIdentityUser(records); err != nil || user != nil { + return user, err + } + } + openid = strings.TrimSpace(openid) channel := strings.TrimSpace(cfg.mode) channelAppID := strings.TrimSpace(cfg.appID) if openid != "" && channel != "" && channelAppID != "" { - record, err := client.AuthIdentityChannel.Query(). + records, err := client.AuthIdentityChannel.Query(). Where( - authidentitychannel.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentitychannel.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(providerKeys...), authidentitychannel.ChannelEQ(channel), authidentitychannel.ChannelAppIDEQ(channelAppID), authidentitychannel.ChannelSubjectEQ(openid), @@ -548,12 +573,12 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( WithIdentity(func(q *dbent.AuthIdentityQuery) { q.WithUser() }). - Only(ctx) - if err != nil && !dbent.IsNotFound(err) { + All(ctx) + if err != nil { return nil, infraerrors.InternalServer("AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) } - if record != nil && record.Edges.Identity != nil && record.Edges.Identity.Edges.User != nil { - return record.Edges.Identity.Edges.User, nil + if user, err := singleWeChatChannelUser(records); err != nil || user != nil { + return user, err } } @@ -561,21 +586,64 @@ func (h *AuthHandler) findWeChatUserByLegacyOpenID( return nil, nil } - record, err := client.AuthIdentity.Query(). + records, err := client.AuthIdentity.Query(). Where( - authidentity.ProviderTypeEQ(strings.TrimSpace(identity.ProviderType)), - authidentity.ProviderKeyEQ(strings.TrimSpace(identity.ProviderKey)), + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(providerKeys...), authidentity.ProviderSubjectEQ(openid), ). WithUser(). - Only(ctx) + All(ctx) if err != nil { - if dbent.IsNotFound(err) { - return nil, nil - } return nil, infraerrors.InternalServer("AUTH_IDENTITY_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) } - return record.Edges.User, nil + return singleWeChatIdentityUser(records) +} + +func wechatCompatibleProviderKeys(providerKey string) []string { + preferred := strings.TrimSpace(providerKey) + if preferred == "" { + preferred = wechatOAuthProviderKey + } + keys := []string{preferred} + if !strings.EqualFold(preferred, wechatOAuthLegacyProviderKey) { + keys = append(keys, wechatOAuthLegacyProviderKey) + } + return keys +} + +func singleWeChatIdentityUser(records []*dbent.AuthIdentity) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.User + continue + } + if resolved.ID != record.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + } + return resolved, nil +} + +func singleWeChatChannelUser(records []*dbent.AuthIdentityChannel) (*dbent.User, error) { + var resolved *dbent.User + for _, record := range records { + if record == nil || record.Edges.Identity == nil || record.Edges.Identity.Edges.User == nil { + continue + } + if resolved == nil { + resolved = record.Edges.Identity.Edges.User + continue + } + if resolved.ID != record.Edges.Identity.Edges.User.ID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + } + return resolved, nil } func (h *AuthHandler) ensureWeChatRuntimeIdentityBinding( diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index dd022fb9..def9d5d6 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -467,6 +467,88 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { require.Zero(t, count) } +func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Conflict Nick","headimgurl":"https://cdn.example/conflict.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetUsername("owner"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + currentUser, err := client.User.Create(). + SetEmail("current@example.com"). + SetUsername("current"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthIntentCookieName, wechatOAuthIntentBind)) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(wechatOAuthBindUserCookieName, buildEncodedOAuthBindUserCookie(t, currentUser.ID, "test-secret"))) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Nil(t, findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName)) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "ownership_conflict", "AUTH_IDENTITY_OWNERSHIP_CONFLICT") + + count, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, count) +} + func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") @@ -703,6 +785,116 @@ func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { require.Equal(t, repairedIdentity.ID, channel.IdentityID) } +func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { + t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") + t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") + t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") + + originalAccessTokenURL := wechatOAuthAccessTokenURL + originalUserInfoURL := wechatOAuthUserInfoURL + t.Cleanup(func() { + wechatOAuthAccessTokenURL = originalAccessTokenURL + wechatOAuthUserInfoURL = originalUserInfoURL + }) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/sns/oauth2/access_token"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"wechat-access","openid":"openid-123","unionid":"union-456","scope":"snsapi_login"}`)) + case strings.Contains(r.URL.Path, "/sns/userinfo"): + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"openid":"openid-123","unionid":"union-456","nickname":"Legacy Canonical","headimgurl":"https://cdn.example/legacy-canonical.png"}`)) + default: + http.NotFound(w, r) + } + })) + defer upstream.Close() + wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" + wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" + + handler, client := newWeChatOAuthTestHandler(t, false) + defer client.Close() + + ctx := context.Background() + legacyUser, err := client.User.Create(). + SetEmail("legacy@example.com"). + SetUsername("legacy-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + legacyIdentity, err := client.AuthIdentity.Create(). + SetUserID(legacyUser.ID). + SetProviderType("wechat"). + SetProviderKey(wechatOAuthLegacyProviderKey). + SetProviderSubject("union-456"). + SetMetadata(map[string]any{"unionid": "union-456"}). + Save(ctx) + require.NoError(t, err) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/callback?code=wechat-code&state=state-123", nil) + req.Host = "api.example.com" + req.AddCookie(encodedCookie(wechatOAuthStateCookieName, "state-123")) + req.AddCookie(encodedCookie(wechatOAuthRedirectCookieName, "/dashboard")) + req.AddCookie(encodedCookie(wechatOAuthModeCookieName, "open")) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-123")) + c.Request = req + + handler.WeChatOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + require.Equal(t, "/auth/wechat/callback", recorder.Header().Get("Location")) + + sessionCookie := findCookie(recorder.Result().Cookies(), oauthPendingSessionCookieName) + require.NotNil(t, sessionCookie) + + session, err := client.PendingAuthSession.Query(). + Where(pendingauthsession.SessionTokenEQ(decodeCookieValueForTest(t, sessionCookie.Value))). + Only(ctx) + require.NoError(t, err) + require.NotNil(t, session.TargetUserID) + require.Equal(t, legacyUser.ID, *session.TargetUserID) + require.Equal(t, legacyUser.Email, session.ResolvedEmail) + + repairedIdentity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, legacyIdentity.ID, repairedIdentity.ID) + require.Equal(t, legacyUser.ID, repairedIdentity.UserID) + + legacyIdentityCount, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ(wechatOAuthLegacyProviderKey), + authidentity.ProviderSubjectEQ("union-456"), + ). + Count(ctx) + require.NoError(t, err) + require.Zero(t, legacyIdentityCount) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ(wechatOAuthProviderKey), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open-app"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, repairedIdentity.ID, channel.IdentityID) +} + func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { t.Helper() diff --git a/backend/migrations/113_normalize_legacy_wechat_provider_key.sql b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql new file mode 100644 index 00000000..15610af0 --- /dev/null +++ b/backend/migrations/113_normalize_legacy_wechat_provider_key.sql @@ -0,0 +1,89 @@ +UPDATE auth_identities AS ai +SET + provider_key = 'wechat-main', + metadata = COALESCE(ai.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identities AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject + ); + +UPDATE auth_identity_channels AS channel +SET + provider_key = 'wechat-main', + metadata = COALESCE(channel.metadata, '{}'::jsonb) || jsonb_build_object( + 'legacy_provider_key', 'wechat', + 'normalized_by_migration', '113_normalize_legacy_wechat_provider_key' + ), + updated_at = NOW() +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' + AND NOT EXISTS ( + SELECT 1 + FROM auth_identity_channels AS canon + WHERE canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject + ); + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_provider_key_conflict', + CAST(ai.id AS TEXT), + jsonb_build_object( + 'legacy_identity_id', ai.id, + 'legacy_user_id', ai.user_id, + 'provider_subject', ai.provider_subject, + 'canonical_identity_id', canon.id, + 'canonical_user_id', canon.user_id, + 'same_user', canon.user_id = ai.user_id, + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identities AS ai +JOIN auth_identities AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.provider_subject = ai.provider_subject +WHERE ai.provider_type = 'wechat' + AND ai.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; + +INSERT INTO auth_identity_migration_reports (report_type, report_key, details) +SELECT + 'wechat_channel_provider_key_conflict', + CAST(channel.id AS TEXT), + jsonb_build_object( + 'legacy_channel_id', channel.id, + 'legacy_identity_id', channel.identity_id, + 'canonical_channel_id', canon.id, + 'canonical_identity_id', canon.identity_id, + 'channel', channel.channel, + 'channel_app_id', channel.channel_app_id, + 'channel_subject', channel.channel_subject, + 'same_user', COALESCE(legacy_identity.user_id = canonical_identity.user_id, FALSE), + 'migration', '113_normalize_legacy_wechat_provider_key' + ) +FROM auth_identity_channels AS channel +JOIN auth_identity_channels AS canon + ON canon.provider_type = 'wechat' + AND canon.provider_key = 'wechat-main' + AND canon.channel = channel.channel + AND canon.channel_app_id = channel.channel_app_id + AND canon.channel_subject = channel.channel_subject +LEFT JOIN auth_identities AS legacy_identity + ON legacy_identity.id = channel.identity_id +LEFT JOIN auth_identities AS canonical_identity + ON canonical_identity.id = canon.identity_id +WHERE channel.provider_type = 'wechat' + AND channel.provider_key = 'wechat' +ON CONFLICT (report_type, report_key) DO NOTHING; -- GitLab From b30982219962c96d4e7c33d6aaeff4171eccbe88 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:46:24 +0800 Subject: [PATCH 059/265] fix: tighten legacy payment provider resolution --- backend/internal/service/payment_refund.go | 36 +++++++---- .../service/payment_webhook_provider_test.go | 64 +++++++++++++++++++ 2 files changed, 86 insertions(+), 14 deletions(-) diff --git a/backend/internal/service/payment_refund.go b/backend/internal/service/payment_refund.go index 01eecff8..57469fa3 100644 --- a/backend/internal/service/payment_refund.go +++ b/backend/internal/service/payment_refund.go @@ -21,7 +21,7 @@ import ( // getOrderProviderInstance looks up the provider instance that processed this order. // For legacy orders without provider_instance_id, it resolves only when the -// enabled instance is uniquely identifiable from the stored order fields. +// historical instance is uniquely identifiable from the stored order fields. func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { if s == nil || s.entClient == nil || o == nil { return nil, nil @@ -40,45 +40,53 @@ func (s *PaymentService) getOrderProviderInstance(ctx context.Context, o *dbent. } func (s *PaymentService) resolveUniqueLegacyOrderProviderInstance(ctx context.Context, o *dbent.PaymentOrder) (*dbent.PaymentProviderInstance, error) { + paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) providerKey := strings.TrimSpace(psStringValue(o.ProviderKey)) if providerKey != "" { instances, err := s.entClient.PaymentProviderInstance.Query(). - Where( - paymentproviderinstance.EnabledEQ(true), - paymentproviderinstance.ProviderKeyEQ(providerKey), - ). + Where(paymentproviderinstance.ProviderKeyEQ(providerKey)). All(ctx) if err != nil { return nil, err } - if len(instances) == 1 { - return instances[0], nil + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil } return nil, nil } - paymentType := payment.GetBasePaymentType(strings.TrimSpace(o.PaymentType)) if paymentType == "" { return nil, nil } instances, err := s.entClient.PaymentProviderInstance.Query(). - Where(paymentproviderinstance.EnabledEQ(true)). All(ctx) if err != nil { return nil, err } + matched := psFilterLegacyOrderProviderInstances(paymentType, instances) + if len(matched) == 1 { + return matched[0], nil + } + return nil, nil +} + +func psFilterLegacyOrderProviderInstances(orderPaymentType string, instances []*dbent.PaymentProviderInstance) []*dbent.PaymentProviderInstance { + if len(instances) == 0 { + return nil + } + if strings.TrimSpace(orderPaymentType) == "" { + return instances + } var matched []*dbent.PaymentProviderInstance for _, inst := range instances { - if psLegacyOrderMatchesInstance(paymentType, inst) { + if psLegacyOrderMatchesInstance(orderPaymentType, inst) { matched = append(matched, inst) } } - if len(matched) == 1 { - return matched[0], nil - } - return nil, nil + return matched } func psLegacyOrderMatchesInstance(orderPaymentType string, inst *dbent.PaymentProviderInstance) bool { diff --git a/backend/internal/service/payment_webhook_provider_test.go b/backend/internal/service/payment_webhook_provider_test.go index 33e4186d..4f0b6848 100644 --- a/backend/internal/service/payment_webhook_provider_test.go +++ b/backend/internal/service/payment_webhook_provider_test.go @@ -141,6 +141,70 @@ func TestGetOrderProviderInstanceLeavesAmbiguousLegacyOrderUnresolved(t *testing require.Nil(t, got) } +func TestGetOrderProviderInstanceLeavesLegacyProviderKeyUnresolvedWhenHistoricalInstancesConflict(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-disabled-legacy"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(false). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeStripe). + SetName("stripe-enabled-current"). + SetConfig("{}"). + SetSupportedTypes("stripe"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeStripe + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeStripe, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + +func TestGetOrderProviderInstanceLeavesProviderKeyMatchUnresolvedWhenTypeNotSupported(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + _, err := client.PaymentProviderInstance.Create(). + SetProviderKey(payment.TypeWxpay). + SetName("wxpay-only"). + SetConfig("{}"). + SetSupportedTypes("wxpay"). + SetEnabled(true). + Save(ctx) + require.NoError(t, err) + + providerKey := payment.TypeWxpay + order := &dbent.PaymentOrder{ + PaymentType: payment.TypeAlipayDirect, + ProviderKey: &providerKey, + } + + svc := &PaymentService{ + entClient: client, + loadBalancer: newWebhookProviderTestLoadBalancer(client), + } + + got, err := svc.getOrderProviderInstance(ctx, order) + require.NoError(t, err) + require.Nil(t, got) +} + func TestGetWebhookProviderRejectsAmbiguousRegistryFallback(t *testing.T) { ctx := context.Background() client := newPaymentConfigServiceTestClient(t) -- GitLab From 31d0183d45b3c5d2a6692daebec58625a393fe38 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:51:57 +0800 Subject: [PATCH 060/265] fix: normalize repository email lookups --- backend/internal/repository/user_repo.go | 30 +++++++- .../user_repo_email_lookup_unit_test.go | 69 +++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 backend/internal/repository/user_repo_email_lookup_unit_test.go diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 0c607ecc..b5efd19d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -12,6 +12,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" "github.com/Wei-Shaw/sub2api/ent/usersubscription" @@ -104,10 +105,20 @@ func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, } func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) { - m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx) + matches, err := r.client.User.Query(). + Where(userEmailLookupPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) if err != nil { - return nil, translatePersistenceError(err, service.ErrUserNotFound, nil) + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound } + if len(matches) > 1 { + return nil, fmt.Errorf("normalized email lookup matched multiple users for %q", strings.TrimSpace(email)) + } + m := matches[0] out := userEntityToService(m) groups, err := r.loadAllowedGroups(ctx, []int64{m.ID}) @@ -469,7 +480,20 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount } func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { - return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) + return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) +} + +func userEmailLookupPredicate(email string) predicate.User { + normalized := strings.TrimSpace(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.ExprP( + fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)), + normalized, + )) + }) } func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error { diff --git a/backend/internal/repository/user_repo_email_lookup_unit_test.go b/backend/internal/repository/user_repo_email_lookup_unit_test.go new file mode 100644 index 00000000..d42ce9ac --- /dev/null +++ b/backend/internal/repository/user_repo_email_lookup_unit_test.go @@ -0,0 +1,69 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUserEntRepo(t *testing.T) (*userRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:user_repo_email_lookup?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return newUserRepositoryWithSQL(client, db), client +} + +func TestUserRepositoryGetByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + got, err := repo.GetByEmail(ctx, "legacy@example.com") + require.NoError(t, err) + require.Equal(t, " Legacy@Example.com ", got.Email) +} + +func TestUserRepositoryExistsByEmailNormalizesLegacySpacingAndCase(t *testing.T) { + repo, _ := newUserEntRepo(t) + ctx := context.Background() + + err := repo.Create(ctx, &service.User{ + Email: " Legacy@Example.com ", + Username: "legacy-user", + PasswordHash: "hash", + Role: service.RoleUser, + Status: service.StatusActive, + }) + require.NoError(t, err) + + exists, err := repo.ExistsByEmail(ctx, " LEGACY@example.com ") + require.NoError(t, err) + require.True(t, exists) +} -- GitLab From aaf4946b27ebc03163d0832b7946234f44b6b769 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 21:59:03 +0800 Subject: [PATCH 061/265] fix: normalize pending oauth email lookups --- .../handler/auth_oauth_pending_flow.go | 47 ++++++++-- .../handler/auth_oauth_pending_flow_test.go | 85 +++++++++++++++++++ 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index fd35e4e5..1b3c1380 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -3,6 +3,7 @@ package handler import ( "context" "errors" + "fmt" "io" "net/http" "net/url" @@ -12,12 +13,14 @@ import ( "github.com/Wei-Shaw/sub2api/ent/authidentity" "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" "github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision" + "github.com/Wei-Shaw/sub2api/ent/predicate" dbuser "github.com/Wei-Shaw/sub2api/ent/user" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" + entsql "entgo.io/ent/dialect/sql" "github.com/gin-gonic/gin" ) @@ -531,11 +534,9 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, return 0, infraerrors.BadRequest("PENDING_AUTH_TARGET_USER_MISSING", "pending auth target user is missing") } - userEntity, err := client.User.Query(). - Where(dbuser.EmailEQ(email)). - Only(ctx) + userEntity, err := findUserByNormalizedEmail(ctx, client, email) if err != nil { - if dbent.IsNotFound(err) { + if errors.Is(err, service.ErrUserNotFound) { return 0, infraerrors.InternalServer("PENDING_AUTH_TARGET_USER_NOT_FOUND", "pending auth target user was not found") } return 0, err @@ -543,6 +544,40 @@ func resolvePendingOAuthTargetUserID(ctx context.Context, client *dbent.Client, return userEntity.ID, nil } +func userNormalizedEmailPredicate(email string) predicate.User { + normalized := strings.TrimSpace(email) + if normalized == "" { + return dbuser.EmailEQ(email) + } + return predicate.User(func(s *entsql.Selector) { + s.Where(entsql.ExprP( + fmt.Sprintf("LOWER(TRIM(%s)) = LOWER(TRIM(?))", s.C(dbuser.FieldEmail)), + normalized, + )) + }) +} + +func findUserByNormalizedEmail(ctx context.Context, client *dbent.Client, email string) (*dbent.User, error) { + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + matches, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) + if err != nil { + return nil, err + } + if len(matches) == 0 { + return nil, service.ErrUserNotFound + } + if len(matches) > 1 { + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } + return matches[0], nil +} + func oauthIdentityIssuer(session *dbent.PendingAuthSession) *string { if session == nil { return nil @@ -1102,8 +1137,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } email := strings.TrimSpace(strings.ToLower(req.Email)) - existingUser, err := client.User.Query().Where(dbuser.EmailEQ(email)).Only(c.Request.Context()) - if err != nil && !dbent.IsNotFound(err) { + existingUser, err := findUserByNormalizedEmail(c.Request.Context(), client, email) + if err != nil && !errors.Is(err, service.ErrUserNotFound) { response.ErrorFrom(c, infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 89accd60..8c468fdc 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -642,6 +642,60 @@ func TestCreateOIDCOAuthAccountExistingEmailReturnsAdoptExistingUserByEmailState require.Zero(t, identityCount) } +func TestCreateOIDCOAuthAccountExistingEmailNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithEmailVerification(t, false, "owner@example.com", "135790") + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("existing-email-normalized-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-existing-normalized-123"). + SetBrowserSessionKey("existing-email-normalized-browser-session-key"). + SetUpstreamIdentityClaims(map[string]any{ + "username": "oidc_user", + "suggested_display_name": "Existing OIDC User", + "suggested_avatar_url": "https://cdn.example/existing.png", + }). + SetRedirectTo("/dashboard"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + body := bytes.NewBufferString(`{"email":"owner@example.com","verify_code":"135790","password":"secret-123"}`) + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodPost, "/api/v1/auth/oauth/oidc/create-account", body) + req.Header.Set("Content-Type", "application/json") + req.AddCookie(&http.Cookie{Name: oauthPendingSessionCookieName, Value: encodeCookieValue(session.SessionToken)}) + req.AddCookie(&http.Cookie{Name: oauthPendingBrowserCookieName, Value: encodeCookieValue("existing-email-normalized-browser-session-key")}) + ginCtx.Request = req + + handler.CreateOIDCOAuthAccount(ginCtx) + + require.Equal(t, http.StatusOK, recorder.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &payload)) + require.Equal(t, "adopt_existing_user_by_email", payload["intent"]) + + storedSession, err := client.PendingAuthSession.Get(ctx, session.ID) + require.NoError(t, err) + require.NotNil(t, storedSession.TargetUserID) + require.Equal(t, existingUser.ID, *storedSession.TargetUserID) + require.Equal(t, "owner@example.com", storedSession.ResolvedEmail) +} + func TestBindOIDCOAuthLoginBindsExistingUserAndConsumesSession(t *testing.T) { handler, client := newOAuthPendingFlowTestHandler(t, false) ctx := context.Background() @@ -884,6 +938,37 @@ func TestBindOIDCOAuthLoginAppliesFirstBindGrantOnce(t *testing.T) { require.Equal(t, 1, countProviderGrantRecords(t, client, existingUser.ID, "oidc", "first_bind")) } +func TestResolvePendingOAuthTargetUserIDNormalizesLegacySpacingAndCase(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + _ = handler + ctx := context.Background() + + existingUser, err := client.User.Create(). + SetEmail(" Owner@Example.com "). + SetUsername("owner-user"). + SetPasswordHash("hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + + session, err := client.PendingAuthSession.Create(). + SetSessionToken("resolve-target-session-token"). + SetIntent("login"). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("oidc-target-123"). + SetResolvedEmail("owner@example.com"). + SetBrowserSessionKey("resolve-target-browser-session-key"). + SetExpiresAt(time.Now().UTC().Add(10 * time.Minute)). + Save(ctx) + require.NoError(t, err) + + resolvedUserID, err := resolvePendingOAuthTargetUserID(ctx, client, session) + require.NoError(t, err) + require.Equal(t, existingUser.ID, resolvedUserID) +} + func TestBindOIDCOAuthLoginReturns2FAChallengeWhenUserHasTotp(t *testing.T) { totpCache := &oauthPendingFlowTotpCacheStub{} handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ -- GitLab From bbc4aed3d91c8b3312c320843e3d4bea54e17d80 Mon Sep 17 00:00:00 2001 From: erio Date: Mon, 20 Apr 2026 22:01:09 +0800 Subject: [PATCH 062/265] =?UTF-8?q?fix(openai):=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=B7=B2=E4=B8=8B=E7=BA=BF=20Codex=20=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=B9=B6=E4=BF=AE=E5=A4=8D=E5=BD=92=E4=B8=80=E5=8C=96=E5=85=9C?= =?UTF-8?q?=E5=BA=95=E5=89=AF=E4=BD=9C=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - backend: 删除 gpt-5 / 5.1 / 5.1-codex / 5.1-codex-max / 5.1-codex-mini / 5.2-codex / 5.4-nano 的内置映射与 DefaultModels 条目 - backend: normalizeCodexModel 默认兜底由 gpt-5.1 改为 gpt-5.4,gpt-5.3-codex-spark 独立保留映射 - backend: 修复 isOpenAIGPT54Model 与 shouldAutoInjectPromptCacheKeyForCompat 对 claude / gpt-4o 的误判(之前依赖 gpt-5.1 作为非 GPT 族的隐式 sentinel,改后需要显式前缀守卫) - backend: 清理 billing_service 中已不可达的 fallback 价格与 switch 分支 - frontend: 从白名单、OpenCode 配置、预设映射中移除已下线模型 - 同步更新所有相关单测 Refs: #1758, parallels upstream #1759 but adds downstream guard fixes --- backend/internal/pkg/openai/constants.go | 9 +- backend/internal/service/billing_service.go | 51 ++-------- .../internal/service/billing_service_test.go | 29 +----- .../service/openai_codex_transform.go | 76 +++------------ .../service/openai_codex_transform_test.go | 30 ++++-- .../service/openai_compat_prompt_cache_key.go | 10 +- .../service/openai_model_mapping_test.go | 14 +-- frontend/src/components/keys/UseKeyModal.vue | 92 ------------------- .../keys/__tests__/UseKeyModal.spec.ts | 4 +- .../__tests__/useModelWhitelist.spec.ts | 19 +++- frontend/src/composables/useModelWhitelist.ts | 15 +-- 11 files changed, 84 insertions(+), 265 deletions(-) diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index 49e38bf8..980f058d 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -17,16 +17,9 @@ type Model struct { var DefaultModels = []Model{ {ID: "gpt-5.4", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4"}, {ID: "gpt-5.4-mini", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Mini"}, - {ID: "gpt-5.4-nano", Object: "model", Created: 1738368000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.4 Nano"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, - {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, - {ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, - {ID: "gpt-5.1-codex", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex"}, - {ID: "gpt-5.1", Object: "model", Created: 1731456000, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1"}, - {ID: "gpt-5.1-codex-mini", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Mini"}, - {ID: "gpt-5", Object: "model", Created: 1722988800, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5"}, } // DefaultModelIDs returns the default model ID list @@ -39,7 +32,7 @@ func DefaultModelIDs() []string { } // DefaultTestModel default model for testing OpenAI accounts -const DefaultTestModel = "gpt-5.1-codex" +const DefaultTestModel = "gpt-5.4" // DefaultInstructions default instructions for non-Codex CLI requests // Content loaded from instructions.txt at compile time diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index c9f32b3b..a45203a3 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -203,17 +203,6 @@ func (s *BillingService) initFallbackPricing() { SupportsCacheBreakdown: false, } - // OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费) - s.fallbackPrices["gpt-5.1"] = &ModelPricing{ - InputPricePerToken: 1.25e-6, // $1.25 per MTok - InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok - OutputPricePerToken: 10e-6, // $10 per MTok - OutputPricePerTokenPriority: 20e-6, // $20 per MTok - CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok - CacheReadPricePerToken: 0.125e-6, - CacheReadPricePerTokenPriority: 0.25e-6, - SupportsCacheBreakdown: false, - } // OpenAI GPT-5.4(业务指定价格) s.fallbackPrices["gpt-5.4"] = &ModelPricing{ InputPricePerToken: 2.5e-6, // $2.5 per MTok @@ -234,12 +223,6 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerToken: 7.5e-8, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.4-nano"] = &ModelPricing{ - InputPricePerToken: 2e-7, - OutputPricePerToken: 1.25e-6, - CacheReadPricePerToken: 2e-8, - SupportsCacheBreakdown: false, - } // OpenAI GPT-5.2(本地兜底) s.fallbackPrices["gpt-5.2"] = &ModelPricing{ InputPricePerToken: 1.75e-6, @@ -251,8 +234,8 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.35e-6, SupportsCacheBreakdown: false, } - // Codex 族兜底统一按 GPT-5.1 Codex 价格计费 - s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{ + // Codex 族兜底统一按 GPT-5.3 Codex 价格计费 + s.fallbackPrices["gpt-5.3-codex"] = &ModelPricing{ InputPricePerToken: 1.5e-6, // $1.5 per MTok InputPricePerTokenPriority: 3e-6, // $3 per MTok OutputPricePerToken: 12e-6, // $12 per MTok @@ -262,17 +245,6 @@ func (s *BillingService) initFallbackPricing() { CacheReadPricePerTokenPriority: 0.3e-6, SupportsCacheBreakdown: false, } - s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{ - InputPricePerToken: 1.75e-6, - InputPricePerTokenPriority: 3.5e-6, - OutputPricePerToken: 14e-6, - OutputPricePerTokenPriority: 28e-6, - CacheCreationPricePerToken: 1.75e-6, - CacheReadPricePerToken: 0.175e-6, - CacheReadPricePerTokenPriority: 0.35e-6, - SupportsCacheBreakdown: false, - } - s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"] } // getFallbackPricing 根据模型系列获取回退价格 @@ -318,20 +290,12 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { switch normalized { case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] - case "gpt-5.4-nano": - return s.fallbackPrices["gpt-5.4-nano"] case "gpt-5.4": return s.fallbackPrices["gpt-5.4"] case "gpt-5.2": return s.fallbackPrices["gpt-5.2"] - case "gpt-5.2-codex": - return s.fallbackPrices["gpt-5.2-codex"] - case "gpt-5.3-codex": + case "gpt-5.3-codex", "gpt-5.3-codex-spark": return s.fallbackPrices["gpt-5.3-codex"] - case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest": - return s.fallbackPrices["gpt-5.1-codex"] - case "gpt-5.1": - return s.fallbackPrices["gpt-5.1"] } } @@ -667,8 +631,13 @@ func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens } func isOpenAIGPT54Model(model string) bool { - normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model))) - return normalized == "gpt-5.4" + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅当模型字符串实际属于 GPT-5/Codex 族时才做归一判定,避免 normalizeCodexModel + // 的默认兜底把非 OpenAI 模型(claude-*、gemini-*、gpt-4o)误识别为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + return normalizeCodexModel(trimmed) == "gpt-5.4" } // CalculateCostWithConfig 使用配置中的默认倍率计算费用 diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go index fc8361c7..222abd69 100644 --- a/backend/internal/service/billing_service_test.go +++ b/backend/internal/service/billing_service_test.go @@ -123,15 +123,6 @@ func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) { require.Contains(t, err.Error(), "pricing not found") } -func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.1") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12) -} - func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) { svc := newTestBillingService() @@ -158,18 +149,6 @@ func TestGetModelPricing_OpenAIGPT54MiniFallback(t *testing.T) { require.Zero(t, pricing.LongContextInputThreshold) } -func TestGetModelPricing_OpenAIGPT54NanoFallback(t *testing.T) { - svc := newTestBillingService() - - pricing, err := svc.GetModelPricing("gpt-5.4-nano") - require.NoError(t, err) - require.NotNil(t, pricing) - require.InDelta(t, 2e-7, pricing.InputPricePerToken, 1e-12) - require.InDelta(t, 1.25e-6, pricing.OutputPricePerToken, 1e-12) - require.InDelta(t, 2e-8, pricing.CacheReadPricePerToken, 1e-12) - require.Zero(t, pricing.LongContextInputThreshold) -} - func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) { svc := newTestBillingService() @@ -204,13 +183,13 @@ func TestGetFallbackPricing_FamilyMatching(t *testing.T) { {name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6}, {name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6}, {name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true}, - {name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6}, {name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6}, {name: "openai gpt5.4 mini", model: "gpt-5.4-mini", expectedInput: 7.5e-7}, - {name: "openai gpt5.4 nano", model: "gpt-5.4-nano", expectedInput: 2e-7}, {name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6}, - {name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6}, - {name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6}, + {name: "openai gpt5.3 codex spark", model: "gpt-5.3-codex-spark", expectedInput: 1.5e-6}, + {name: "openai legacy gpt5.1 falls back to gpt5.4", model: "gpt-5.1", expectedInput: 2.5e-6}, + {name: "openai legacy gpt5.1 codex falls back to gpt5.3 codex", model: "gpt-5.1-codex", expectedInput: 1.5e-6}, + {name: "openai legacy codex mini latest falls back to gpt5.3 codex", model: "codex-mini-latest", expectedInput: 1.5e-6}, {name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true}, {name: "non supported family", model: "qwen-max", expectNilPricing: true}, } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a266d6a0..457309d3 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -8,7 +8,6 @@ import ( var codexModelMap = map[string]string{ "gpt-5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", "gpt-5.4-none": "gpt-5.4", "gpt-5.4-low": "gpt-5.4", "gpt-5.4-medium": "gpt-5.4", @@ -22,52 +21,21 @@ var codexModelMap = map[string]string{ "gpt-5.3-high": "gpt-5.3-codex", "gpt-5.3-xhigh": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-low": "gpt-5.3-codex", - "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3-codex-low": "gpt-5.3-codex", "gpt-5.3-codex-medium": "gpt-5.3-codex", "gpt-5.3-codex-high": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": "gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", "gpt-5.2": "gpt-5.2", "gpt-5.2-none": "gpt-5.2", "gpt-5.2-low": "gpt-5.2", "gpt-5.2-medium": "gpt-5.2", "gpt-5.2-high": "gpt-5.2", "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -220,7 +188,7 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact func normalizeCodexModel(model string) string { if model == "" { - return "gpt-5.1" + return "gpt-5.4" } modelID := model @@ -238,49 +206,29 @@ func normalizeCodexModel(model string) string { if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { return "gpt-5.4-mini" } - if strings.Contains(normalized, "gpt-5.4-nano") || strings.Contains(normalized, "gpt 5.4 nano") { - return "gpt-5.4-nano" - } if strings.Contains(normalized, "gpt-5.4") || strings.Contains(normalized, "gpt 5.4") { return "gpt-5.4" } - if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { - return "gpt-5.2-codex" - } if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { return "gpt-5.2" } + if strings.Contains(normalized, "gpt-5.3-codex-spark") || strings.Contains(normalized, "gpt 5.3 codex spark") { + return "gpt-5.3-codex-spark" + } if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { return "gpt-5.3-codex" } - if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { - return "gpt-5.1-codex-max" - } - if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") { - return "gpt-5.1-codex-mini" - } - if strings.Contains(normalized, "codex-mini-latest") || - strings.Contains(normalized, "gpt-5-codex-mini") || - strings.Contains(normalized, "gpt 5 codex mini") { - return "codex-mini-latest" - } - if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") { - return "gpt-5.1-codex" - } - if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") { - return "gpt-5.1" - } if strings.Contains(normalized, "codex") { - return "gpt-5.1-codex" + return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { - return "gpt-5.1" + return "gpt-5.4" } - return "gpt-5.1" + return "gpt-5.4" } func normalizeOpenAIModelForUpstream(account *Account, model string) string { diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 993ade07..22264f5e 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -240,15 +240,13 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt 5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", "gpt 5.4 mini": "gpt-5.4-mini", - "gpt-5.4-nano": "gpt-5.4-nano", - "gpt 5.4 nano": "gpt-5.4-nano", "gpt-5.3": "gpt-5.3-codex", "gpt-5.3-codex": "gpt-5.3-codex", "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt 5.3 codex spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt 5.3 codex spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt 5.3 codex": "gpt-5.3-codex", } @@ -257,6 +255,26 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { } } +func TestNormalizeCodexModel_RemovedModelsFallbackToSupportedTargets(t *testing.T) { + cases := map[string]string{ + "": "gpt-5.4", + "gpt-5": "gpt-5.4", + "gpt-5-mini": "gpt-5.4", + "gpt-5-nano": "gpt-5.4", + "gpt-5.1": "gpt-5.4", + "gpt-5.1-codex": "gpt-5.3-codex", + "gpt-5.1-codex-max": "gpt-5.3-codex", + "gpt-5.1-codex-mini": "gpt-5.3-codex", + "gpt-5.2-codex": "gpt-5.2", + "codex-mini-latest": "gpt-5.3-codex", + "gpt-5-codex": "gpt-5.3-codex", + } + + for input, expected := range cases { + require.Equal(t, expected, normalizeCodexModel(input)) + } +} + func TestApplyCodexOAuthTransform_PreservesBareSparkModel(t *testing.T) { reqBody := map[string]any{ "model": "gpt-5.3-codex-spark", diff --git a/backend/internal/service/openai_compat_prompt_cache_key.go b/backend/internal/service/openai_compat_prompt_cache_key.go index 88e16a4d..fcd27f19 100644 --- a/backend/internal/service/openai_compat_prompt_cache_key.go +++ b/backend/internal/service/openai_compat_prompt_cache_key.go @@ -10,8 +10,14 @@ import ( const compatPromptCacheKeyPrefix = "compat_cc_" func shouldAutoInjectPromptCacheKeyForCompat(model string) bool { - switch normalizeCodexModel(strings.TrimSpace(model)) { - case "gpt-5.4", "gpt-5.3-codex": + trimmed := strings.TrimSpace(strings.ToLower(model)) + // 仅对 Codex OAuth 路径支持的 GPT-5 族开启自动注入,避免 normalizeCodexModel + // 的默认兜底把任意模型(如 gpt-4o、claude-*)误判为 gpt-5.4。 + if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { + return false + } + switch normalizeCodexModel(trimmed) { + case "gpt-5.4", "gpt-5.3-codex", "gpt-5.3-codex-spark": return true default: return false diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index cda7e369..35e7c250 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -69,14 +69,14 @@ func TestResolveOpenAIForwardModel(t *testing.T) { } } -func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t *testing.T) { +func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt54(t *testing.T) { account := &Account{ Credentials: map[string]any{}, } withoutDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "")) - if withoutDefault != "gpt-5.1" { - t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.1") + if withoutDefault != "gpt-5.4" { + t.Fatalf("normalizeCodexModel(...) = %q, want %q", withoutDefault, "gpt-5.4") } withDefault := normalizeCodexModel(resolveOpenAIForwardModel(account, "claude-opus-4-6", "gpt-5.4")) @@ -87,9 +87,9 @@ func TestResolveOpenAIForwardModel_PreventsClaudeModelFromFallingBackToGpt51(t * func TestNormalizeCodexModel(t *testing.T) { cases := map[string]string{ - "gpt-5.3-codex-spark": "gpt-5.3-codex", - "gpt-5.3-codex-spark-high": "gpt-5.3-codex", - "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", + "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", } @@ -111,7 +111,7 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { name: "oauth keeps codex normalization behavior", account: &Account{Type: AccountTypeOAuth}, model: "gemini-3-flash-preview", - want: "gpt-5.1", + want: "gpt-5.4", }, { name: "apikey preserves custom compatible model", diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 7770e658..b3679107 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -617,66 +617,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin } } const openaiModels = { - 'gpt-5-codex': { - name: 'GPT-5 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex': { - name: 'GPT-5.1 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex-max': { - name: 'GPT-5.1 Codex Max', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, - 'gpt-5.1-codex-mini': { - name: 'GPT-5.1 Codex Mini', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {} - } - }, 'gpt-5.2': { name: 'GPT-5.2', limit: { @@ -725,22 +665,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, - 'gpt-5.4-nano': { - name: 'GPT-5.4 Nano', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {}, - xhigh: {} - } - }, 'gpt-5.3-codex-spark': { name: 'GPT-5.3 Codex Spark', limit: { @@ -773,22 +697,6 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin xhigh: {} } }, - 'gpt-5.2-codex': { - name: 'GPT-5.2 Codex', - limit: { - context: 400000, - output: 128000 - }, - options: { - store: false - }, - variants: { - low: {}, - medium: {}, - high: {}, - xhigh: {} - } - }, 'codex-mini-latest': { name: 'Codex Mini', limit: { diff --git a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts index 98b5dede..f7db586a 100644 --- a/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts +++ b/frontend/src/components/keys/__tests__/UseKeyModal.spec.ts @@ -17,7 +17,7 @@ vi.mock('@/composables/useClipboard', () => ({ import UseKeyModal from '../UseKeyModal.vue' describe('UseKeyModal', () => { - it('renders updated GPT-5.4 mini/nano names in OpenCode config', async () => { + it('renders GPT-5.4 mini entry in OpenCode config', async () => { const wrapper = mount(UseKeyModal, { props: { show: true, @@ -48,6 +48,6 @@ describe('UseKeyModal', () => { const codeBlock = wrapper.find('pre code') expect(codeBlock.exists()).toBe(true) expect(codeBlock.text()).toContain('"name": "GPT-5.4 Mini"') - expect(codeBlock.text()).toContain('"name": "GPT-5.4 Nano"') + expect(codeBlock.text()).not.toContain('"name": "GPT-5.4 Nano"') }) }) diff --git a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts index 4061be4d..d35e3b12 100644 --- a/frontend/src/composables/__tests__/useModelWhitelist.spec.ts +++ b/frontend/src/composables/__tests__/useModelWhitelist.spec.ts @@ -12,10 +12,20 @@ describe('useModelWhitelist', () => { expect(models).toContain('gpt-5.4') expect(models).toContain('gpt-5.4-mini') - expect(models).toContain('gpt-5.4-nano') expect(models).toContain('gpt-5.4-2026-03-05') }) + it('openai 模型列表不再暴露已下线的 ChatGPT 登录 Codex 模型', () => { + const models = getModelsByPlatform('openai') + + expect(models).not.toContain('gpt-5') + expect(models).not.toContain('gpt-5.1') + expect(models).not.toContain('gpt-5.1-codex') + expect(models).not.toContain('gpt-5.1-codex-max') + expect(models).not.toContain('gpt-5.1-codex-mini') + expect(models).not.toContain('gpt-5.2-codex') + }) + it('antigravity 模型列表包含图片模型兼容项', () => { const models = getModelsByPlatform('antigravity') @@ -55,12 +65,11 @@ describe('useModelWhitelist', () => { }) }) - it('whitelist keeps GPT-5.4 mini and nano exact mappings', () => { - const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini', 'gpt-5.4-nano'], []) + it('whitelist keeps GPT-5.4 mini exact mappings', () => { + const mapping = buildModelMappingObject('whitelist', ['gpt-5.4-mini'], []) expect(mapping).toEqual({ - 'gpt-5.4-mini': 'gpt-5.4-mini', - 'gpt-5.4-nano': 'gpt-5.4-nano' + 'gpt-5.4-mini': 'gpt-5.4-mini' }) }) }) diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts index a282ae7d..ddd7af48 100644 --- a/frontend/src/composables/useModelWhitelist.ts +++ b/frontend/src/composables/useModelWhitelist.ts @@ -13,19 +13,11 @@ const openaiModels = [ 'o1', 'o1-preview', 'o1-mini', 'o1-pro', 'o3', 'o3-mini', 'o3-pro', 'o4-mini', - // GPT-5 系列(同步后端定价文件) - 'gpt-5', 'gpt-5-2025-08-07', 'gpt-5-chat', 'gpt-5-chat-latest', - 'gpt-5-codex', 'gpt-5.3-codex-spark', 'gpt-5-pro', 'gpt-5-pro-2025-10-06', - 'gpt-5-mini', 'gpt-5-mini-2025-08-07', - 'gpt-5-nano', 'gpt-5-nano-2025-08-07', - // GPT-5.1 系列 - 'gpt-5.1', 'gpt-5.1-2025-11-13', 'gpt-5.1-chat-latest', - 'gpt-5.1-codex', 'gpt-5.1-codex-max', 'gpt-5.1-codex-mini', // GPT-5.2 系列 'gpt-5.2', 'gpt-5.2-2025-12-11', 'gpt-5.2-chat-latest', - 'gpt-5.2-codex', 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', + 'gpt-5.2-pro', 'gpt-5.2-pro-2025-12-11', // GPT-5.4 系列 - 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-nano', 'gpt-5.4-2026-03-05', + 'gpt-5.4', 'gpt-5.4-mini', 'gpt-5.4-2026-03-05', // GPT-5.3 系列 'gpt-5.3-codex', 'gpt-5.3-codex-spark', 'chatgpt-4o-latest', @@ -264,12 +256,9 @@ const openaiPresetMappings = [ { label: 'GPT-4.1', from: 'gpt-4.1', to: 'gpt-4.1', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' }, { label: 'o1', from: 'o1', to: 'o1', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'o3', from: 'o3', to: 'o3', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, - { label: 'GPT-5', from: 'gpt-5', to: 'gpt-5', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, { label: 'GPT-5.3 Codex Spark', from: 'gpt-5.3-codex-spark', to: 'gpt-5.3-codex-spark', color: 'bg-teal-100 text-teal-700 hover:bg-teal-200 dark:bg-teal-900/30 dark:text-teal-400' }, - { label: 'GPT-5.1', from: 'gpt-5.1', to: 'gpt-5.1', color: 'bg-orange-100 text-orange-700 hover:bg-orange-200 dark:bg-orange-900/30 dark:text-orange-400' }, { label: 'GPT-5.2', from: 'gpt-5.2', to: 'gpt-5.2', color: 'bg-red-100 text-red-700 hover:bg-red-200 dark:bg-red-900/30 dark:text-red-400' }, { label: 'GPT-5.4', from: 'gpt-5.4', to: 'gpt-5.4', color: 'bg-rose-100 text-rose-700 hover:bg-rose-200 dark:bg-rose-900/30 dark:text-rose-400' }, - { label: 'GPT-5.1 Codex', from: 'gpt-5.1-codex', to: 'gpt-5.1-codex', color: 'bg-cyan-100 text-cyan-700 hover:bg-cyan-200 dark:bg-cyan-900/30 dark:text-cyan-400' }, { label: 'Haiku→5.4', from: 'claude-haiku-4-5-20251001', to: 'gpt-5.4', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }, { label: 'Opus→5.4', from: 'claude-opus-4-6', to: 'gpt-5.4', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }, { label: 'Sonnet→5.4', from: 'claude-sonnet-4-6', to: 'gpt-5.4', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' } -- GitLab From 3bd3027251dd5e008e60ba7299ef63dbc8f8a32c Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:05:33 +0800 Subject: [PATCH 063/265] feat: expose auth identity migration reports --- .../admin/admin_basic_handlers_test.go | 12 ++ .../handler/admin/admin_service_stub_test.go | 34 +++++ .../internal/handler/admin/user_handler.go | 25 ++++ backend/internal/server/routes/admin.go | 2 + backend/internal/service/admin_service.go | 135 ++++++++++++++++++ ..._service_identity_migration_report_test.go | 92 ++++++++++++ 6 files changed, 300 insertions(+) create mode 100644 backend/internal/service/admin_service_identity_migration_report_test.go diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index cba3ae21..ff9eec7e 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -22,6 +22,8 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { redeemHandler := NewRedeemHandler(adminSvc, nil) router.GET("/api/v1/admin/users", userHandler.List) + router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) + router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) @@ -70,6 +72,16 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports/summary", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/auth-identity-migration-reports?report_type=oidc_synthetic_email_requires_manual_recovery", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 6d1ef1b6..681c25c6 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -17,6 +17,7 @@ type stubAdminService struct { proxies []service.Proxy proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode + migrationReports []service.AuthIdentityMigrationReport createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -123,6 +124,15 @@ func newStubAdminService() *stubAdminService { proxies: []service.Proxy{proxy}, proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}}, redeems: []service.RedeemCode{redeem}, + migrationReports: []service.AuthIdentityMigrationReport{ + { + ID: 1, + ReportType: "oidc_synthetic_email_requires_manual_recovery", + ReportKey: "u-1", + Details: map[string]any{"user_id": 1}, + CreatedAt: now, + }, + }, } } @@ -167,6 +177,30 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]service.AuthIdentityMigrationReport, int64, error) { + if reportType == "" { + return s.migrationReports, int64(len(s.migrationReports)), nil + } + filtered := make([]service.AuthIdentityMigrationReport, 0, len(s.migrationReports)) + for _, report := range s.migrationReports { + if strings.EqualFold(report.ReportType, reportType) { + filtered = append(filtered, report) + } + } + return filtered, int64(len(filtered)), nil +} + +func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*service.AuthIdentityMigrationReportSummary, error) { + summary := &service.AuthIdentityMigrationReportSummary{ + ByType: map[string]int64{}, + } + for _, report := range s.migrationReports { + summary.Total++ + summary.ByType[report.ReportType]++ + } + return summary, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 1453bd07..ee3fbb1e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -172,6 +172,31 @@ func (h *UserHandler) GetByID(c *gin.Context) { response.Success(c, dto.UserFromServiceAdmin(user)) } +// GetAuthIdentityMigrationReportSummary returns aggregate migration report counts. +// GET /api/v1/admin/users/auth-identity-migration-reports/summary +func (h *UserHandler) GetAuthIdentityMigrationReportSummary(c *gin.Context) { + summary, err := h.adminService.GetAuthIdentityMigrationReportSummary(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, summary) +} + +// ListAuthIdentityMigrationReports returns paginated auth identity migration reports. +// GET /api/v1/admin/users/auth-identity-migration-reports +func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + reportType := strings.TrimSpace(c.Query("report_type")) + + reports, total, err := h.adminService.ListAuthIdentityMigrationReports(c.Request.Context(), reportType, page, pageSize) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, reports, total, page, pageSize) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 9af0fd8e..0b5aaf09 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -210,6 +210,8 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users := admin.Group("/users") { + users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary) + users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) users.POST("", h.Admin.User.Create) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 7c26a47c..972681a5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,6 +2,8 @@ package service import ( "context" + "database/sql" + "encoding/json" "errors" "fmt" "io" @@ -16,6 +18,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" + + entsql "entgo.io/ent/dialect/sql" ) // AdminService interface defines admin management operations @@ -33,6 +37,8 @@ type AdminService interface { // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) + ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) + GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -127,6 +133,19 @@ type UpdateUserInput struct { GroupRates map[int64]*float64 } +type AuthIdentityMigrationReport struct { + ID int64 `json:"id"` + ReportType string `json:"report_type"` + ReportKey string `json:"report_key"` + Details map[string]any `json:"details"` + CreatedAt time.Time `json:"created_at"` +} + +type AuthIdentityMigrationReportSummary struct { + Total int64 `json:"total"` + ByType map[string]int64 `json:"by_type"` +} + type CreateGroupInput struct { Name string Description string @@ -788,6 +807,122 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int return codes, result.Total, totalRecharged, nil } +func (s *adminServiceImpl) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) { + db, err := s.adminSQLDB() + if err != nil { + return nil, 0, err + } + + reportType = strings.TrimSpace(reportType) + if page <= 0 { + page = 1 + } + if pageSize <= 0 { + pageSize = 20 + } + offset := (page - 1) * pageSize + + var total int64 + if err := db.QueryRowContext(ctx, ` +SELECT COUNT(*) +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1)`, + reportType, + ).Scan(&total); err != nil { + return nil, 0, err + } + + rows, err := db.QueryContext(ctx, ` +SELECT id, report_type, report_key, details, created_at +FROM auth_identity_migration_reports +WHERE ($1 = '' OR report_type = $1) +ORDER BY created_at DESC, id DESC +LIMIT $2 OFFSET $3`, + reportType, + pageSize, + offset, + ) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + + reports := make([]AuthIdentityMigrationReport, 0) + for rows.Next() { + report, scanErr := scanAuthIdentityMigrationReport(rows) + if scanErr != nil { + return nil, 0, scanErr + } + reports = append(reports, report) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + return reports, total, nil +} + +func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) { + db, err := s.adminSQLDB() + if err != nil { + return nil, err + } + + rows, err := db.QueryContext(ctx, ` +SELECT report_type, COUNT(*) +FROM auth_identity_migration_reports +GROUP BY report_type +ORDER BY report_type ASC`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + summary := &AuthIdentityMigrationReportSummary{ + ByType: make(map[string]int64), + } + for rows.Next() { + var reportType string + var count int64 + if err := rows.Scan(&reportType, &count); err != nil { + return nil, err + } + summary.ByType[reportType] = count + summary.Total += count + } + if err := rows.Err(); err != nil { + return nil, err + } + return summary, nil +} + +func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { + if s == nil || s.entClient == nil { + return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") + } + driver, ok := s.entClient.Driver().(*entsql.Driver) + if !ok || driver.DB() == nil { + return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") + } + return driver.DB(), nil +} + +func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { + var ( + report AuthIdentityMigrationReport + details []byte + ) + if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + return AuthIdentityMigrationReport{}, err + } + report.Details = map[string]any{} + if len(details) > 0 { + if err := json.Unmarshal(details, &report.Details); err != nil { + return AuthIdentityMigrationReport{}, err + } + } + return report, nil +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go new file mode 100644 index 00000000..75ca3e5a --- /dev/null +++ b/backend/internal/service/admin_service_identity_migration_report_test.go @@ -0,0 +1,92 @@ +package service + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_migration_reports?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + _, err = db.Exec(`CREATE TABLE auth_identity_migration_reports ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + report_type TEXT NOT NULL, + report_key TEXT NOT NULL, + details TEXT NOT NULL DEFAULT '{}', + created_at DATETIME NOT NULL + )`) + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceListAuthIdentityMigrationReports(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ($1, $2, $3, $4), + ($5, $6, $7, $8)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute), + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + reports, total, err := svc.ListAuthIdentityMigrationReports(context.Background(), "oidc_synthetic_email_requires_manual_recovery", 1, 20) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Len(t, reports, 1) + require.Equal(t, "oidc_synthetic_email_requires_manual_recovery", reports[0].ReportType) + require.Equal(t, float64(1), reports[0].Details["user_id"]) +} + +func TestAdminServiceGetAuthIdentityMigrationReportSummary(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES + ($1, $2, $3, $4), + ($5, $6, $7, $8), + ($9, $10, $11, $12)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + "wechat_provider_key_conflict", "u-2", `{"user_id":2}`, now.Add(-time.Minute), + "wechat_provider_key_conflict", "u-3", `{"user_id":3}`, now.Add(-2*time.Minute), + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(3), summary.Total) + require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) + require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"]) +} -- GitLab From 452e55a53c7a22c3e4a0da2421d19cfb6819de96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:22:14 +0800 Subject: [PATCH 064/265] feat: add admin auth identity repair binding --- .../admin/admin_basic_handlers_test.go | 48 +++- .../handler/admin/admin_service_stub_test.go | 48 ++++ .../internal/handler/admin/user_handler.go | 55 ++++ backend/internal/server/routes/admin.go | 1 + backend/internal/service/admin_service.go | 262 ++++++++++++++++++ ...dmin_service_auth_identity_binding_test.go | 215 ++++++++++++++ 6 files changed, 628 insertions(+), 1 deletion(-) create mode 100644 backend/internal/service/admin_service_auth_identity_binding_test.go diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index ff9eec7e..57620005 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -25,6 +25,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) router.PUT("/api/v1/admin/users/:id", userHandler.Update) router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) @@ -87,8 +88,26 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + bindBody := map[string]any{ + "provider_type": "wechat", + "provider_key": "wechat-main", + "provider_subject": "union-123", + "metadata": map[string]any{"source": "admin-repair"}, + "channel": map[string]any{ + "channel": "open", + "channel_app_id": "wx-open", + "channel_subject": "openid-123", + }, + } + body, _ := json.Marshal(bindBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} - body, _ := json.Marshal(createBody) + body, _ = json.Marshal(createBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") @@ -125,6 +144,33 @@ func TestUserHandlerEndpoints(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) } +func TestUserHandlerBindAuthIdentityMapsRequest(t *testing.T) { + router, adminSvc := setupAdminRouter() + + body, err := json.Marshal(map[string]any{ + "provider_type": "oidc", + "provider_key": "https://issuer.example", + "provider_subject": "subject-123", + "issuer": "https://issuer.example", + "metadata": map[string]any{"report_id": 12}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/9/auth-identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(9), adminSvc.boundAuthIdentityFor) + require.NotNil(t, adminSvc.boundAuthIdentity) + require.Equal(t, "oidc", adminSvc.boundAuthIdentity.ProviderType) + require.Equal(t, "https://issuer.example", adminSvc.boundAuthIdentity.ProviderKey) + require.Equal(t, "subject-123", adminSvc.boundAuthIdentity.ProviderSubject) + require.Nil(t, adminSvc.boundAuthIdentity.Channel) + require.Equal(t, float64(12), adminSvc.boundAuthIdentity.Metadata["report_id"]) +} + func TestGroupHandlerEndpoints(t *testing.T) { router, _ := setupAdminRouter() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 681c25c6..c8c7a247 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -18,6 +18,8 @@ type stubAdminService struct { proxyCounts []service.ProxyWithAccountCount redeems []service.RedeemCode migrationReports []service.AuthIdentityMigrationReport + boundAuthIdentity *service.AdminBindAuthIdentityInput + boundAuthIdentityFor int64 createdAccounts []*service.CreateAccountInput createdProxies []*service.CreateProxyInput updatedProxyIDs []int64 @@ -201,6 +203,52 @@ func (s *stubAdminService) GetAuthIdentityMigrationReportSummary(ctx context.Con return summary, nil } +func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { + s.boundAuthIdentityFor = userID + copied := input + if input.Metadata != nil { + copied.Metadata = map[string]any{} + for key, value := range input.Metadata { + copied.Metadata[key] = value + } + } + if input.Channel != nil { + channel := *input.Channel + if input.Channel.Metadata != nil { + channel.Metadata = map[string]any{} + for key, value := range input.Channel.Metadata { + channel.Metadata[key] = value + } + } + copied.Channel = &channel + } + s.boundAuthIdentity = &copied + + now := time.Now().UTC() + result := &service.AdminBoundAuthIdentity{ + UserID: userID, + ProviderType: input.ProviderType, + ProviderKey: input.ProviderKey, + ProviderSubject: input.ProviderSubject, + VerifiedAt: &now, + Issuer: input.Issuer, + Metadata: input.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + if input.Channel != nil { + result.Channel = &service.AdminBoundAuthIdentityChannel{ + Channel: input.Channel.Channel, + ChannelAppID: input.Channel.ChannelAppID, + ChannelSubject: input.Channel.ChannelSubject, + Metadata: input.Channel.Metadata, + CreatedAt: now, + UpdatedAt: now, + } + } + return result, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index ee3fbb1e..321214af 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -66,6 +66,22 @@ type UpdateBalanceRequest struct { Notes string `json:"notes"` } +type BindUserAuthIdentityRequest struct { + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + Issuer *string `json:"issuer"` + Metadata map[string]any `json:"metadata"` + Channel *BindUserAuthIdentityChannelRequest `json:"channel"` +} + +type BindUserAuthIdentityChannelRequest struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -197,6 +213,45 @@ func (h *UserHandler) ListAuthIdentityMigrationReports(c *gin.Context) { response.Paginated(c, reports, total, page, pageSize) } +// BindAuthIdentity manually binds a canonical auth identity to a user. +// POST /api/v1/admin/users/:id/auth-identities +func (h *UserHandler) BindAuthIdentity(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + var req BindUserAuthIdentityRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + input := service.AdminBindAuthIdentityInput{ + ProviderType: req.ProviderType, + ProviderKey: req.ProviderKey, + ProviderSubject: req.ProviderSubject, + Issuer: req.Issuer, + Metadata: req.Metadata, + } + if req.Channel != nil { + input.Channel = &service.AdminBindAuthIdentityChannelInput{ + Channel: req.Channel.Channel, + ChannelAppID: req.Channel.ChannelAppID, + ChannelSubject: req.Channel.ChannelSubject, + Metadata: req.Channel.Metadata, + } + } + + result, err := h.adminService.BindUserAuthIdentity(c.Request.Context(), userID, input) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 0b5aaf09..c78fba33 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -214,6 +214,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) + users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) users.POST("", h.Admin.User.Create) users.PUT("/:id", h.Admin.User.Update) users.DELETE("/:id", h.Admin.User.Delete) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 972681a5..9ff26861 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -13,6 +13,8 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -39,6 +41,7 @@ type AdminService interface { GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) + BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -146,6 +149,44 @@ type AuthIdentityMigrationReportSummary struct { ByType map[string]int64 `json:"by_type"` } +type AdminBindAuthIdentityInput struct { + ProviderType string + ProviderKey string + ProviderSubject string + Issuer *string + Metadata map[string]any + Channel *AdminBindAuthIdentityChannelInput +} + +type AdminBindAuthIdentityChannelInput struct { + Channel string + ChannelAppID string + ChannelSubject string + Metadata map[string]any +} + +type AdminBoundAuthIdentity struct { + UserID int64 `json:"user_id"` + ProviderType string `json:"provider_type"` + ProviderKey string `json:"provider_key"` + ProviderSubject string `json:"provider_subject"` + VerifiedAt *time.Time `json:"verified_at,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Channel *AdminBoundAuthIdentityChannel `json:"channel,omitempty"` +} + +type AdminBoundAuthIdentityChannel struct { + Channel string `json:"channel"` + ChannelAppID string `json:"channel_app_id"` + ChannelSubject string `json:"channel_subject"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + type CreateGroupInput struct { Name string Description string @@ -895,6 +936,143 @@ ORDER BY report_type ASC`) return summary, nil } +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identity, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyEQ(providerKey), + authidentity.ProviderSubjectEQ(providerSubject), + ). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if identity != nil && identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetProviderSubject(providerSubject). + SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID).SetVerifiedAt(verifiedAt) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channel, err = tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyEQ(providerKey), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). + Only(ctx) + if err != nil && !dbent.IsNotFound(err) { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if channel != nil && channel.Edges.Identity != nil && channel.Edges.Identity.UserID != userID { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(providerKey). + SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID).SetIdentityID(identity.ID) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} + func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { if s == nil || s.entClient == nil { return nil, infraerrors.ServiceUnavailable("ADMIN_SQL_NOT_READY", "admin sql access is not ready") @@ -906,6 +1084,90 @@ func (s *adminServiceImpl) adminSQLDB() (*sql.DB, error) { return driver.DB(), nil } +func normalizeAdminBindChannelInput(input *AdminBindAuthIdentityChannelInput) *AdminBindAuthIdentityChannelInput { + if input == nil { + return nil + } + channel := &AdminBindAuthIdentityChannelInput{ + Channel: strings.TrimSpace(input.Channel), + ChannelAppID: strings.TrimSpace(input.ChannelAppID), + ChannelSubject: strings.TrimSpace(input.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(input.Metadata), + } + if channel.Channel == "" || channel.ChannelAppID == "" || channel.ChannelSubject == "" { + return nil + } + return channel +} + +func normalizeAdminAuthIdentityProviderType(input string) string { + switch strings.ToLower(strings.TrimSpace(input)) { + case "email": + return "email" + case "linuxdo": + return "linuxdo" + case "oidc": + return "oidc" + case "wechat": + return "wechat" + default: + return "" + } +} + +func buildAdminBoundAuthIdentity(identity *dbent.AuthIdentity, channel *dbent.AuthIdentityChannel) *AdminBoundAuthIdentity { + if identity == nil { + return nil + } + result := &AdminBoundAuthIdentity{ + UserID: identity.UserID, + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + VerifiedAt: identity.VerifiedAt, + Issuer: identity.Issuer, + Metadata: cloneAdminAuthIdentityMetadata(identity.Metadata), + CreatedAt: identity.CreatedAt, + UpdatedAt: identity.UpdatedAt, + } + if channel != nil { + result.Channel = &AdminBoundAuthIdentityChannel{ + Channel: strings.TrimSpace(channel.Channel), + ChannelAppID: strings.TrimSpace(channel.ChannelAppID), + ChannelSubject: strings.TrimSpace(channel.ChannelSubject), + Metadata: cloneAdminAuthIdentityMetadata(channel.Metadata), + CreatedAt: channel.CreatedAt, + UpdatedAt: channel.UpdatedAt, + } + } + return result +} + +func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { + if input == nil { + return nil + } + if len(input) == 0 { + return map[string]any{} + } + data, err := json.Marshal(input) + if err != nil { + out := make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + return out + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + out = make(map[string]any, len(input)) + for key, value := range input { + out[key] = value + } + } + return out +} + func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { var ( report AuthIdentityMigrationReport diff --git a/backend/internal/service/admin_service_auth_identity_binding_test.go b/backend/internal/service/admin_service_auth_identity_binding_test.go new file mode 100644 index 00000000..f8ce3935 --- /dev/null +++ b/backend/internal/service/admin_service_auth_identity_binding_test.go @@ -0,0 +1,215 @@ +//go:build unit + +package service + +import ( + "context" + "database/sql" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + "github.com/Wei-Shaw/sub2api/ent/enttest" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAdminServiceAuthIdentityBindingTestClient(t *testing.T) *dbent.Client { + t.Helper() + + db, err := sql.Open("sqlite", "file:admin_service_auth_identity_binding?mode=memory&cache=shared&_fk=1") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestAdminServiceBindUserAuthIdentityCreatesCanonicalAndChannelBinding(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("bind-target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + result, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "wechat", + ProviderKey: "wechat-main", + ProviderSubject: "union-123", + Metadata: map[string]any{"source": "admin-repair"}, + Channel: &AdminBindAuthIdentityChannelInput{ + Channel: "open", + ChannelAppID: "wx-open", + ChannelSubject: "openid-123", + Metadata: map[string]any{"scene": "migration"}, + }, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, user.ID, result.UserID) + require.Equal(t, "wechat", result.ProviderType) + require.Equal(t, "wechat-main", result.ProviderKey) + require.NotNil(t, result.VerifiedAt) + require.NotNil(t, result.Channel) + require.Equal(t, "open", result.Channel.Channel) + + identity, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("wechat"), + authidentity.ProviderKeyEQ("wechat-main"), + authidentity.ProviderSubjectEQ("union-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, user.ID, identity.UserID) + require.Equal(t, "admin-repair", identity.Metadata["source"]) + require.NotNil(t, identity.VerifiedAt) + + channel, err := client.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ("wechat"), + authidentitychannel.ProviderKeyEQ("wechat-main"), + authidentitychannel.ChannelEQ("open"), + authidentitychannel.ChannelAppIDEQ("wx-open"), + authidentitychannel.ChannelSubjectEQ("openid-123"), + ). + Only(ctx) + require.NoError(t, err) + require.Equal(t, identity.ID, channel.IdentityID) + require.Equal(t, "migration", channel.Metadata["scene"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsOtherOwner(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + owner, err := client.User.Create(). + SetEmail("owner@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + target, err := client.User.Create(). + SetEmail("target@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + _, err = client.AuthIdentity.Create(). + SetUserID(owner.ID). + SetProviderType("oidc"). + SetProviderKey("https://issuer.example"). + SetProviderSubject("subject-1"). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: target.ID, Email: target.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, target.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-1", + }) + require.Error(t, err) + require.Equal(t, "AUTH_IDENTITY_OWNERSHIP_CONFLICT", infraerrors.Reason(err)) +} + +func TestAdminServiceBindUserAuthIdentityIsIdempotentForSameUser(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("same-user@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + first, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "first"}, + }) + require.NoError(t, err) + + second, err := svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example", + ProviderSubject: "subject-2", + Metadata: map[string]any{"source": "second"}, + }) + require.NoError(t, err) + require.Equal(t, first.UserID, second.UserID) + require.Equal(t, "second", second.Metadata["source"]) + + identities, err := client.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example"), + authidentity.ProviderSubjectEQ("subject-2"), + ). + All(ctx) + require.NoError(t, err) + require.Len(t, identities, 1) + require.Equal(t, "second", identities[0].Metadata["source"]) +} + +func TestAdminServiceBindUserAuthIdentityRejectsInvalidProviderType(t *testing.T) { + client := newAdminServiceAuthIdentityBindingTestClient(t) + ctx := context.Background() + + user, err := client.User.Create(). + SetEmail("invalid-provider@example.com"). + SetPasswordHash("hash"). + SetRole(RoleUser). + SetStatus(StatusActive). + Save(ctx) + require.NoError(t, err) + + svc := &adminServiceImpl{ + userRepo: &userRepoStub{user: &User{ID: user.ID, Email: user.Email, Status: StatusActive}}, + entClient: client, + } + + _, err = svc.BindUserAuthIdentity(ctx, user.ID, AdminBindAuthIdentityInput{ + ProviderType: "github", + ProviderKey: "github-main", + ProviderSubject: "subject-3", + }) + require.Error(t, err) + require.Equal(t, "INVALID_INPUT", infraerrors.Reason(err)) +} -- GitLab From 724f8e89a1d8a61b9847883dea8573894b6b5c02 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:29:21 +0800 Subject: [PATCH 065/265] feat: resolve auth identity migration reports --- .../admin/admin_basic_handlers_test.go | 14 ++- .../handler/admin/admin_service_stub_test.go | 14 +++ .../internal/handler/admin/user_handler.go | 39 ++++++ backend/internal/server/routes/admin.go | 1 + backend/internal/service/admin_service.go | 111 ++++++++++++++++-- ..._service_identity_migration_report_test.go | 34 +++++- ...h_identity_migration_report_resolution.sql | 11 ++ 7 files changed, 209 insertions(+), 15 deletions(-) create mode 100644 backend/migrations/114_auth_identity_migration_report_resolution.sql diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 57620005..207931d9 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "testing" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -24,6 +25,10 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.GET("/api/v1/admin/users", userHandler.List) router.GET("/api/v1/admin/users/auth-identity-migration-reports/summary", userHandler.GetAuthIdentityMigrationReportSummary) router.GET("/api/v1/admin/users/auth-identity-migration-reports", userHandler.ListAuthIdentityMigrationReports) + router.POST("/api/v1/admin/users/auth-identity-migration-reports/:id/resolve", func(c *gin.Context) { + c.Set(string(servermiddleware.ContextKeyUser), servermiddleware.AuthSubject{UserID: 99}) + userHandler.ResolveAuthIdentityMigrationReport(c) + }) router.GET("/api/v1/admin/users/:id", userHandler.GetByID) router.POST("/api/v1/admin/users/:id/auth-identities", userHandler.BindAuthIdentity) router.POST("/api/v1/admin/users", userHandler.Create) @@ -83,6 +88,13 @@ func TestUserHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + body, _ := json.Marshal(map[string]any{"resolution_note": "resolved by manual bind"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/auth-identity-migration-reports/1/resolve", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) router.ServeHTTP(rec, req) @@ -99,7 +111,7 @@ func TestUserHandlerEndpoints(t *testing.T) { "channel_subject": "openid-123", }, } - body, _ := json.Marshal(bindBody) + body, _ = json.Marshal(bindBody) rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/auth-identities", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index c8c7a247..8ecadfdf 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -249,6 +249,20 @@ func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int6 return result, nil } +func (s *stubAdminService) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*service.AuthIdentityMigrationReport, error) { + now := time.Now().UTC() + for i := range s.migrationReports { + if s.migrationReports[i].ID != reportID { + continue + } + s.migrationReports[i].ResolvedAt = &now + s.migrationReports[i].ResolvedByUserID = &resolvedByUserID + s.migrationReports[i].ResolutionNote = resolutionNote + return &s.migrationReports[i], nil + } + return nil, nil +} + func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 321214af..e582e322 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -7,6 +7,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -82,6 +83,10 @@ type BindUserAuthIdentityChannelRequest struct { Metadata map[string]any `json:"metadata"` } +type ResolveAuthIdentityMigrationReportRequest struct { + ResolutionNote string `json:"resolution_note"` +} + // List handles listing all users with pagination // GET /api/v1/admin/users // Query params: @@ -252,6 +257,40 @@ func (h *UserHandler) BindAuthIdentity(c *gin.Context) { response.Success(c, result) } +// ResolveAuthIdentityMigrationReport marks a migration report as resolved. +// POST /api/v1/admin/users/auth-identity-migration-reports/:id/resolve +func (h *UserHandler) ResolveAuthIdentityMigrationReport(c *gin.Context) { + reportID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid report ID") + return + } + + subject, ok := servermiddleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Authentication required") + return + } + + var req ResolveAuthIdentityMigrationReportRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + report, err := h.adminService.ResolveAuthIdentityMigrationReport( + c.Request.Context(), + reportID, + subject.UserID, + req.ResolutionNote, + ) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, report) +} + // Create handles creating a new user // POST /api/v1/admin/users func (h *UserHandler) Create(c *gin.Context) { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index c78fba33..e5c0eac1 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -212,6 +212,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { users.GET("/auth-identity-migration-reports/summary", h.Admin.User.GetAuthIdentityMigrationReportSummary) users.GET("/auth-identity-migration-reports", h.Admin.User.ListAuthIdentityMigrationReports) + users.POST("/auth-identity-migration-reports/:id/resolve", h.Admin.User.ResolveAuthIdentityMigrationReport) users.GET("", h.Admin.User.List) users.GET("/:id", h.Admin.User.GetByID) users.POST("/:id/auth-identities", h.Admin.User.BindAuthIdentity) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 9ff26861..3490374e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,7 @@ type AdminService interface { ListAuthIdentityMigrationReports(ctx context.Context, reportType string, page, pageSize int) ([]AuthIdentityMigrationReport, int64, error) GetAuthIdentityMigrationReportSummary(ctx context.Context) (*AuthIdentityMigrationReportSummary, error) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) + ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) @@ -137,16 +138,21 @@ type UpdateUserInput struct { } type AuthIdentityMigrationReport struct { - ID int64 `json:"id"` - ReportType string `json:"report_type"` - ReportKey string `json:"report_key"` - Details map[string]any `json:"details"` - CreatedAt time.Time `json:"created_at"` + ID int64 `json:"id"` + ReportType string `json:"report_type"` + ReportKey string `json:"report_key"` + Details map[string]any `json:"details"` + CreatedAt time.Time `json:"created_at"` + ResolvedAt *time.Time `json:"resolved_at,omitempty"` + ResolvedByUserID *int64 `json:"resolved_by_user_id,omitempty"` + ResolutionNote string `json:"resolution_note,omitempty"` } type AuthIdentityMigrationReportSummary struct { - Total int64 `json:"total"` - ByType map[string]int64 `json:"by_type"` + Total int64 `json:"total"` + OpenTotal int64 `json:"open_total"` + ResolvedTotal int64 `json:"resolved_total"` + ByType map[string]int64 `json:"by_type"` } type AdminBindAuthIdentityInput struct { @@ -874,7 +880,7 @@ WHERE ($1 = '' OR report_type = $1)`, } rows, err := db.QueryContext(ctx, ` -SELECT id, report_type, report_key, details, created_at +SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note FROM auth_identity_migration_reports WHERE ($1 = '' OR report_type = $1) ORDER BY created_at DESC, id DESC @@ -909,7 +915,11 @@ func (s *adminServiceImpl) GetAuthIdentityMigrationReportSummary(ctx context.Con } rows, err := db.QueryContext(ctx, ` -SELECT report_type, COUNT(*) +SELECT + report_type, + COUNT(*), + SUM(CASE WHEN resolved_at IS NULL THEN 1 ELSE 0 END), + SUM(CASE WHEN resolved_at IS NOT NULL THEN 1 ELSE 0 END) FROM auth_identity_migration_reports GROUP BY report_type ORDER BY report_type ASC`) @@ -924,11 +934,15 @@ ORDER BY report_type ASC`) for rows.Next() { var reportType string var count int64 - if err := rows.Scan(&reportType, &count); err != nil { + var openCount int64 + var resolvedCount int64 + if err := rows.Scan(&reportType, &count, &openCount, &resolvedCount); err != nil { return nil, err } summary.ByType[reportType] = count summary.Total += count + summary.OpenTotal += openCount + summary.ResolvedTotal += resolvedCount } if err := rows.Err(); err != nil { return nil, err @@ -936,6 +950,56 @@ ORDER BY report_type ASC`) return summary, nil } +func (s *adminServiceImpl) ResolveAuthIdentityMigrationReport(ctx context.Context, reportID, resolvedByUserID int64, resolutionNote string) (*AuthIdentityMigrationReport, error) { + if reportID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "report id must be greater than 0") + } + if resolvedByUserID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "resolved_by_user_id must be greater than 0") + } + + db, err := s.adminSQLDB() + if err != nil { + return nil, err + } + + now := time.Now().UTC() + result, err := db.ExecContext(ctx, ` +UPDATE auth_identity_migration_reports +SET + resolved_at = COALESCE(resolved_at, $2), + resolved_by_user_id = COALESCE(resolved_by_user_id, $3), + resolution_note = $4 +WHERE id = $1`, + reportID, + now, + resolvedByUserID, + strings.TrimSpace(resolutionNote), + ) + if err != nil { + return nil, err + } + affected, err := result.RowsAffected() + if err != nil { + return nil, err + } + if affected == 0 { + return nil, infraerrors.NotFound("AUTH_IDENTITY_MIGRATION_REPORT_NOT_FOUND", "auth identity migration report not found") + } + + row := db.QueryRowContext(ctx, ` +SELECT id, report_type, report_key, details, created_at, resolved_at, resolved_by_user_id, resolution_note +FROM auth_identity_migration_reports +WHERE id = $1`, + reportID, + ) + report, err := scanAuthIdentityMigrationReport(row) + if err != nil { + return nil, err + } + return &report, nil +} + func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { if userID <= 0 { return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") @@ -1170,10 +1234,22 @@ func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error }) (AuthIdentityMigrationReport, error) { var ( - report AuthIdentityMigrationReport - details []byte + report AuthIdentityMigrationReport + details []byte + resolvedAt sql.NullTime + resolvedByUserID sql.NullInt64 + resolutionNote sql.NullString ) - if err := scanner.Scan(&report.ID, &report.ReportType, &report.ReportKey, &details, &report.CreatedAt); err != nil { + if err := scanner.Scan( + &report.ID, + &report.ReportType, + &report.ReportKey, + &details, + &report.CreatedAt, + &resolvedAt, + &resolvedByUserID, + &resolutionNote, + ); err != nil { return AuthIdentityMigrationReport{}, err } report.Details = map[string]any{} @@ -1182,6 +1258,15 @@ func scanAuthIdentityMigrationReport(scanner interface{ Scan(dest ...any) error return AuthIdentityMigrationReport{}, err } } + if resolvedAt.Valid { + report.ResolvedAt = &resolvedAt.Time + } + if resolvedByUserID.Valid { + report.ResolvedByUserID = &resolvedByUserID.Int64 + } + if resolutionNote.Valid { + report.ResolutionNote = resolutionNote.String + } return report, nil } diff --git a/backend/internal/service/admin_service_identity_migration_report_test.go b/backend/internal/service/admin_service_identity_migration_report_test.go index 75ca3e5a..6975604b 100644 --- a/backend/internal/service/admin_service_identity_migration_report_test.go +++ b/backend/internal/service/admin_service_identity_migration_report_test.go @@ -30,7 +30,10 @@ func newAdminServiceMigrationReportTestClient(t *testing.T) *dbent.Client { report_type TEXT NOT NULL, report_key TEXT NOT NULL, details TEXT NOT NULL DEFAULT '{}', - created_at DATETIME NOT NULL + created_at DATETIME NOT NULL, + resolved_at DATETIME NULL, + resolved_by_user_id INTEGER NULL, + resolution_note TEXT NOT NULL DEFAULT '' )`) require.NoError(t, err) @@ -87,6 +90,35 @@ VALUES summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) require.NoError(t, err) require.Equal(t, int64(3), summary.Total) + require.Equal(t, int64(3), summary.OpenTotal) + require.Zero(t, summary.ResolvedTotal) require.Equal(t, int64(1), summary.ByType["oidc_synthetic_email_requires_manual_recovery"]) require.Equal(t, int64(2), summary.ByType["wechat_provider_key_conflict"]) } + +func TestAdminServiceResolveAuthIdentityMigrationReport(t *testing.T) { + client := newAdminServiceMigrationReportTestClient(t) + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + + now := time.Now().UTC() + _, err := driver.DB().ExecContext(context.Background(), ` +INSERT INTO auth_identity_migration_reports (report_type, report_key, details, created_at) +VALUES ($1, $2, $3, $4)`, + "oidc_synthetic_email_requires_manual_recovery", "u-1", `{"user_id":1}`, now, + ) + require.NoError(t, err) + + svc := &adminServiceImpl{entClient: client} + report, err := svc.ResolveAuthIdentityMigrationReport(context.Background(), 1, 99, "resolved by admin binding") + require.NoError(t, err) + require.NotNil(t, report.ResolvedAt) + require.NotNil(t, report.ResolvedByUserID) + require.Equal(t, int64(99), *report.ResolvedByUserID) + require.Equal(t, "resolved by admin binding", report.ResolutionNote) + + summary, err := svc.GetAuthIdentityMigrationReportSummary(context.Background()) + require.NoError(t, err) + require.Zero(t, summary.OpenTotal) + require.Equal(t, int64(1), summary.ResolvedTotal) +} diff --git a/backend/migrations/114_auth_identity_migration_report_resolution.sql b/backend/migrations/114_auth_identity_migration_report_resolution.sql new file mode 100644 index 00000000..f84bf822 --- /dev/null +++ b/backend/migrations/114_auth_identity_migration_report_resolution.sql @@ -0,0 +1,11 @@ +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_at TIMESTAMPTZ NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolved_by_user_id BIGINT NULL; + +ALTER TABLE auth_identity_migration_reports + ADD COLUMN IF NOT EXISTS resolution_note TEXT NOT NULL DEFAULT ''; + +CREATE INDEX IF NOT EXISTS idx_auth_identity_migration_reports_resolved_at + ON auth_identity_migration_reports (resolved_at); -- GitLab From bffcc2042eeeb61eb0e5c9fc0c7cd4d40ff89f96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Mon, 20 Apr 2026 22:37:25 +0800 Subject: [PATCH 066/265] fix: complete oidc pending auth callback flows --- frontend/src/views/auth/OidcCallbackView.vue | 420 ++++++++++++++---- .../auth/__tests__/OidcCallbackView.spec.ts | 187 +++++--- 2 files changed, 457 insertions(+), 150 deletions(-) diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue index deb40b20..de3f2e40 100644 --- a/frontend/src/views/auth/OidcCallbackView.vue +++ b/frontend/src/views/auth/OidcCallbackView.vue @@ -18,9 +18,10 @@
@@ -108,46 +109,131 @@ - @@ -281,6 +283,7 @@ import SubscriptionPlanCard from '@/components/payment/SubscriptionPlanCard.vue' import PaymentStatusPanel from '@/components/payment/PaymentStatusPanel.vue' import Icon from '@/components/icons/Icon.vue' import type { PaymentMethodOption } from '@/components/payment/PaymentMethodSelector.vue' +import { describePaymentScenarioError } from './paymentUx' const { t } = useI18n() const route = useRoute() @@ -301,6 +304,7 @@ function getDaysRemaining(expiresAt: string): number { const loading = ref(true) const submitting = ref(false) const errorMessage = ref('') +const errorHintMessage = ref('') const activeTab = ref<'recharge' | 'subscription'>('recharge') const amount = ref(null) const selectedMethod = ref('') @@ -619,6 +623,7 @@ async function confirmSubscribe() { async function createOrder(orderAmount: number, orderType: OrderType, planId?: number, options: CreateOrderOptions = {}) { submitting.value = true errorMessage.value = '' + errorHintMessage.value = '' try { const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value const payload = buildCreateOrderPayload({ @@ -668,8 +673,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } if (decision.kind === 'unhandled') { - errorMessage.value = t('payment.result.failed') - appStore.showError(errorMessage.value) + applyScenarioError({ reason: 'UNHANDLED_PAYMENT_SCENARIO' }, visibleMethod) return } @@ -691,7 +695,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n if (errMsg.includes('cancel')) { appStore.showInfo(t('payment.qr.cancelled')) } else if (errMsg && !errMsg.includes('ok')) { - appStore.showError(t('payment.result.failed')) + applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) } return } @@ -707,10 +711,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n if (apiErr.reason === 'TOO_MANY_PENDING') { const metadata = apiErr.metadata as Record | undefined errorMessage.value = t('payment.errors.tooManyPending', { max: metadata?.max || '' }) + errorHintMessage.value = '' } else if (apiErr.reason === 'CANCEL_RATE_LIMITED') { errorMessage.value = t('payment.errors.cancelRateLimited') + errorHintMessage.value = '' } else { - errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + applyScenarioError(err, normalizeVisibleMethod(options.paymentType || selectedMethod.value) || selectedMethod.value) + if (!errorMessage.value) { + errorMessage.value = extractApiErrorMessage(err, t('payment.result.failed')) + errorHintMessage.value = '' + } } appStore.showError(errorMessage.value) } finally { @@ -718,6 +728,21 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } } +function applyScenarioError(err: unknown, paymentMethod: string) { + const descriptor = describePaymentScenarioError(err, { + paymentMethod, + isMobile: isMobileDevice(), + isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent), + }) + if (!descriptor) { + errorMessage.value = '' + errorHintMessage.value = '' + return + } + errorMessage.value = t(descriptor.messageKey) + errorHintMessage.value = descriptor.hintKey ? t(descriptor.hintKey) : '' +} + async function resumeWechatPaymentFromQuery() { const openid = readRouteQueryValue(route.query.openid) if (readRouteQueryValue(route.query.wechat_resume) !== '1' || !openid) { diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts index d23a60d9..b1caa526 100644 --- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts @@ -155,4 +155,29 @@ describe('PaymentResultView', () => { expect(wrapper.text()).toContain('payment.result.success') expect(verifyOrderPublic).not.toHaveBeenCalled() }) + + it('normalizes aliased payment methods before rendering the label', async () => { + routeState.query = { + resume_token: 'resume-88', + } + resolveOrderPublicByResumeToken.mockResolvedValueOnce({ + data: { + ...orderFactory('PAID'), + payment_type: 'alipay_direct', + }, + }) + + const wrapper = mount(PaymentResultView, { + global: { + stubs: { + OrderStatusBadge: true, + }, + }, + }) + + await flushPromises() + + expect(wrapper.text()).toContain('payment.methods.alipay') + expect(wrapper.text()).not.toContain('payment.methods.alipay_direct') + }) }) diff --git a/frontend/src/views/user/__tests__/paymentUx.spec.ts b/frontend/src/views/user/__tests__/paymentUx.spec.ts new file mode 100644 index 00000000..6cf105f2 --- /dev/null +++ b/frontend/src/views/user/__tests__/paymentUx.spec.ts @@ -0,0 +1,49 @@ +import { describe, expect, it } from 'vitest' +import { + describePaymentScenarioError, + normalizePaymentMethodForDisplay, +} from '../paymentUx' + +describe('normalizePaymentMethodForDisplay', () => { + it('collapses visible payment aliases to canonical method ids', () => { + expect(normalizePaymentMethodForDisplay(' alipay_direct ')).toBe('alipay') + expect(normalizePaymentMethodForDisplay('wxpay_direct')).toBe('wxpay') + expect(normalizePaymentMethodForDisplay('wechat_pay')).toBe('wxpay') + }) + + it('leaves non-aliased methods untouched', () => { + expect(normalizePaymentMethodForDisplay('stripe')).toBe('stripe') + }) +}) + +describe('describePaymentScenarioError', () => { + it('maps WeChat H5 authorization errors to explicit in-app guidance', () => { + expect(describePaymentScenarioError( + { reason: 'WECHAT_H5_NOT_AUTHORIZED' }, + { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: false }, + )).toEqual({ + messageKey: 'payment.errors.wechatH5NotAuthorized', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + }) + }) + + it('maps missing WeixinJSBridge to a JSAPI-specific prompt', () => { + expect(describePaymentScenarioError( + new Error('WeixinJSBridge is unavailable'), + { paymentMethod: 'wxpay', isMobile: true, isWechatBrowser: true }, + )).toEqual({ + messageKey: 'payment.errors.wechatJsapiUnavailable', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + }) + }) + + it('maps generic desktop Alipay failures to QR guidance', () => { + expect(describePaymentScenarioError( + { reason: 'PAYMENT_GATEWAY_ERROR' }, + { paymentMethod: 'alipay', isMobile: false, isWechatBrowser: false }, + )).toEqual({ + messageKey: 'payment.errors.alipayDesktopUnavailable', + hintKey: 'payment.errors.alipayDesktopQrHint', + }) + }) +}) diff --git a/frontend/src/views/user/paymentUx.ts b/frontend/src/views/user/paymentUx.ts new file mode 100644 index 00000000..443529a7 --- /dev/null +++ b/frontend/src/views/user/paymentUx.ts @@ -0,0 +1,105 @@ +import { normalizeVisibleMethod } from '@/components/payment/paymentFlow' +import { extractApiErrorCode } from '@/utils/apiError' + +const DISPLAY_METHOD_ALIASES: Record = { + wechat: 'wxpay', + wechat_pay: 'wxpay', +} + +export interface PaymentScenarioContext { + paymentMethod: string + isMobile: boolean + isWechatBrowser: boolean +} + +export interface PaymentScenarioErrorDescriptor { + messageKey: string + hintKey?: string +} + +export function normalizePaymentMethodForDisplay(paymentType: string): string { + const trimmed = paymentType.trim().toLowerCase() + const visibleMethod = normalizeVisibleMethod(trimmed) + if (visibleMethod) return visibleMethod + return DISPLAY_METHOD_ALIASES[trimmed] ?? trimmed +} + +export function paymentMethodI18nKey(paymentType: string): string { + return `payment.methods.${normalizePaymentMethodForDisplay(paymentType)}` +} + +function defaultWechatHint(context: PaymentScenarioContext): string { + if (!context.isMobile) return 'payment.errors.wechatScanOnDesktopHint' + return 'payment.errors.wechatOpenInWeChatHint' +} + +function defaultAlipayHint(context: PaymentScenarioContext): string { + if (context.isMobile) return 'payment.errors.alipayMobileOpenHint' + return 'payment.errors.alipayDesktopQrHint' +} + +export function describePaymentScenarioError( + error: unknown, + context: PaymentScenarioContext, +): PaymentScenarioErrorDescriptor | null { + const method = normalizePaymentMethodForDisplay(context.paymentMethod) + const code = extractApiErrorCode(error) + const message = error instanceof Error + ? error.message + : (typeof error === 'object' && error && 'message' in error && typeof error.message === 'string' + ? error.message + : String(error || '')) + const normalizedMessage = message.toLowerCase() + + if (method === 'wxpay') { + if (code === 'WECHAT_H5_NOT_AUTHORIZED') { + return { + messageKey: 'payment.errors.wechatH5NotAuthorized', + hintKey: defaultWechatHint(context), + } + } + if (code === 'WECHAT_PAYMENT_MP_NOT_CONFIGURED') { + return { + messageKey: 'payment.errors.wechatPaymentMpNotConfigured', + hintKey: context.isWechatBrowser + ? 'payment.errors.wechatSwitchBrowserHint' + : defaultWechatHint(context), + } + } + if (code === 'NO_AVAILABLE_INSTANCE') { + return { + messageKey: 'payment.errors.wechatUnavailable', + hintKey: defaultWechatHint(context), + } + } + if (code === 'WECHAT_JSAPI_FAILED' || normalizedMessage.includes('get_brand_wcpay_request:fail')) { + return { + messageKey: 'payment.errors.wechatJsapiFailed', + hintKey: defaultWechatHint(context), + } + } + if (normalizedMessage.includes('weixinjsbridge is unavailable')) { + return { + messageKey: 'payment.errors.wechatJsapiUnavailable', + hintKey: 'payment.errors.wechatOpenInWeChatHint', + } + } + if (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO') { + return { + messageKey: 'payment.errors.wechatUnavailable', + hintKey: defaultWechatHint(context), + } + } + } + + if (method === 'alipay' && (code === 'PAYMENT_GATEWAY_ERROR' || code === 'UNHANDLED_PAYMENT_SCENARIO')) { + return { + messageKey: context.isMobile + ? 'payment.errors.alipayMobileUnavailable' + : 'payment.errors.alipayDesktopUnavailable', + hintKey: defaultAlipayHint(context), + } + } + + return null +} -- GitLab From 9e84e2fd2bed765d38ca4a438fbbf1739950ae32 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:05:17 +0800 Subject: [PATCH 076/265] fix: persist admin payment visibility and scheduler settings --- .../internal/handler/admin/setting_handler.go | 64 ++++++++++++++++ ...tting_handler_auth_source_defaults_test.go | 74 +++++++++++++++++++ backend/internal/handler/dto/settings.go | 9 +++ backend/internal/server/api_contract_test.go | 63 +++++++++++++++- .../service/admin_service_apikey_test.go | 6 ++ backend/internal/service/setting_service.go | 54 +++++++++++++- .../service/setting_service_update_test.go | 31 ++++++++ backend/internal/service/settings_view.go | 9 +++ backend/internal/service/user_service_test.go | 6 ++ 9 files changed, 311 insertions(+), 5 deletions(-) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index fe5c7928..e5681208 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -180,6 +180,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + PaymentVisibleMethodAlipaySource: settings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: settings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: settings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: settings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: settings.OpenAIAdvancedSchedulerEnabled, BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, @@ -338,6 +343,15 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"` + // Balance low notification BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` @@ -935,6 +949,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + PaymentVisibleMethodAlipaySource: func() string { + if req.PaymentVisibleMethodAlipaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodAlipaySource) + } + return previousSettings.PaymentVisibleMethodAlipaySource + }(), + PaymentVisibleMethodWxpaySource: func() string { + if req.PaymentVisibleMethodWxpaySource != nil { + return strings.TrimSpace(*req.PaymentVisibleMethodWxpaySource) + } + return previousSettings.PaymentVisibleMethodWxpaySource + }(), + PaymentVisibleMethodAlipayEnabled: func() bool { + if req.PaymentVisibleMethodAlipayEnabled != nil { + return *req.PaymentVisibleMethodAlipayEnabled + } + return previousSettings.PaymentVisibleMethodAlipayEnabled + }(), + PaymentVisibleMethodWxpayEnabled: func() bool { + if req.PaymentVisibleMethodWxpayEnabled != nil { + return *req.PaymentVisibleMethodWxpayEnabled + } + return previousSettings.PaymentVisibleMethodWxpayEnabled + }(), + OpenAIAdvancedSchedulerEnabled: func() bool { + if req.OpenAIAdvancedSchedulerEnabled != nil { + return *req.OpenAIAdvancedSchedulerEnabled + } + return previousSettings.OpenAIAdvancedSchedulerEnabled + }(), BalanceLowNotifyEnabled: func() bool { if req.BalanceLowNotifyEnabled != nil { return *req.BalanceLowNotifyEnabled @@ -1153,6 +1197,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + PaymentVisibleMethodAlipaySource: updatedSettings.PaymentVisibleMethodAlipaySource, + PaymentVisibleMethodWxpaySource: updatedSettings.PaymentVisibleMethodWxpaySource, + PaymentVisibleMethodAlipayEnabled: updatedSettings.PaymentVisibleMethodAlipayEnabled, + PaymentVisibleMethodWxpayEnabled: updatedSettings.PaymentVisibleMethodWxpayEnabled, + OpenAIAdvancedSchedulerEnabled: updatedSettings.OpenAIAdvancedSchedulerEnabled, BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, @@ -1455,6 +1504,21 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + if before.PaymentVisibleMethodAlipaySource != after.PaymentVisibleMethodAlipaySource { + changed = append(changed, "payment_visible_method_alipay_source") + } + if before.PaymentVisibleMethodWxpaySource != after.PaymentVisibleMethodWxpaySource { + changed = append(changed, "payment_visible_method_wxpay_source") + } + if before.PaymentVisibleMethodAlipayEnabled != after.PaymentVisibleMethodAlipayEnabled { + changed = append(changed, "payment_visible_method_alipay_enabled") + } + if before.PaymentVisibleMethodWxpayEnabled != after.PaymentVisibleMethodWxpayEnabled { + changed = append(changed, "payment_visible_method_wxpay_enabled") + } + if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled { + changed = append(changed, "openai_advanced_scheduler_enabled") + } // Balance & quota notification if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { changed = append(changed, "balance_low_notify_enabled") diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index b26fa447..bf51fc68 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -147,3 +147,77 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes require.Equal(t, float64(8), data["auth_source_default_email_concurrency"]) require.Equal(t, true, data["force_email_on_third_party_signup"]) } + +func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "easypay", + "payment_visible_method_wxpay_source": "wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, repo.values[service.SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, repo.values[service.SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.values[service.SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.values[service.SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.values["openai_advanced_scheduler_enabled"]) + + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, service.VisibleMethodSourceEasyPayAlipay, data["payment_visible_method_alipay_source"]) + require.Equal(t, service.VisibleMethodSourceOfficialWechat, data["payment_visible_method_wxpay_source"]) + require.Equal(t, true, data["payment_visible_method_alipay_enabled"]) + require.Equal(t, false, data["payment_visible_method_wxpay_enabled"]) + require.Equal(t, true, data["openai_advanced_scheduler_enabled"]) +} + +func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := &settingHandlerRepoStub{ + values: map[string]string{ + service.SettingKeyPromoCodeEnabled: "true", + }, + } + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + + body := map[string]any{ + "promo_code_enabled": true, + "payment_visible_method_alipay_source": "bogus", + } + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.NotContains(t, repo.values, service.SettingPaymentVisibleMethodAlipaySource) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 637e317b..cc3f8496 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -127,6 +127,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` + // Payment visible method routing + PaymentVisibleMethodAlipaySource string `json:"payment_visible_method_alipay_source"` + PaymentVisibleMethodWxpaySource string `json:"payment_visible_method_wxpay_source"` + PaymentVisibleMethodAlipayEnabled bool `json:"payment_visible_method_alipay_enabled"` + PaymentVisibleMethodWxpayEnabled bool `json:"payment_visible_method_wxpay_enabled"` + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool `json:"openai_advanced_scheduler_enabled"` + // Payment configuration PaymentEnabled bool `json:"payment_enabled"` PaymentMinAmount float64 `json:"payment_min_amount"` diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e903898f..533f2dac 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -500,10 +500,15 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyTableDefaultPageSize: "20", service.SettingKeyTablePageSizeOptions: "[10,20,50,100]", - service.SettingKeyOpsMonitoringEnabled: "false", - service.SettingKeyOpsRealtimeMonitoringEnabled: "true", - service.SettingKeyOpsQueryModeDefault: "auto", - service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingKeyOpsMonitoringEnabled: "false", + service.SettingKeyOpsRealtimeMonitoringEnabled: "true", + service.SettingKeyOpsQueryModeDefault: "auto", + service.SettingKeyOpsMetricsIntervalSeconds: "60", + service.SettingPaymentVisibleMethodAlipaySource: service.VisibleMethodSourceEasyPayAlipay, + service.SettingPaymentVisibleMethodWxpaySource: service.VisibleMethodSourceOfficialWechat, + service.SettingPaymentVisibleMethodAlipayEnabled: "true", + service.SettingPaymentVisibleMethodWxpayEnabled: "false", + "openai_advanced_scheduler_enabled": "true", }) }, method: http.MethodGet, @@ -567,6 +572,27 @@ func TestAPIContracts(t *testing.T) { "api_base_url": "https://api.example.com", "contact_info": "support", "doc_url": "https://docs.example.com", + "auth_source_default_email_balance": 0, + "auth_source_default_email_concurrency": 5, + "auth_source_default_email_subscriptions": [], + "auth_source_default_email_grant_on_signup": false, + "auth_source_default_email_grant_on_first_bind": false, + "auth_source_default_linuxdo_balance": 0, + "auth_source_default_linuxdo_concurrency": 5, + "auth_source_default_linuxdo_subscriptions": [], + "auth_source_default_linuxdo_grant_on_signup": false, + "auth_source_default_linuxdo_grant_on_first_bind": false, + "auth_source_default_oidc_balance": 0, + "auth_source_default_oidc_concurrency": 5, + "auth_source_default_oidc_subscriptions": [], + "auth_source_default_oidc_grant_on_signup": false, + "auth_source_default_oidc_grant_on_first_bind": false, + "auth_source_default_wechat_balance": 0, + "auth_source_default_wechat_concurrency": 5, + "auth_source_default_wechat_subscriptions": [], + "auth_source_default_wechat_grant_on_signup": false, + "auth_source_default_wechat_grant_on_first_bind": false, + "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, "default_subscriptions": [], @@ -592,6 +618,11 @@ func TestAPIContracts(t *testing.T) { "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "web_search_emulation_enabled": false, + "payment_visible_method_alipay_source": "easypay_alipay", + "payment_visible_method_wxpay_source": "official_wxpay", + "payment_visible_method_alipay_enabled": true, + "payment_visible_method_wxpay_enabled": false, + "openai_advanced_scheduler_enabled": true, "custom_menu_items": [], "custom_endpoints": [], "payment_enabled": false, @@ -858,6 +889,18 @@ func (r *stubUserRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } +func (r *stubUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + return nil, nil +} + +func (r *stubUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + func (r *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } @@ -894,6 +937,18 @@ func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64 return errors.New("not implemented") } +func (r *stubUserRepo) ListUserAuthIdentities(ctx context.Context, userID int64) ([]service.UserAuthIdentityRecord, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserRepo) GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *stubUserRepo) GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) { + return nil, nil +} + func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { return errors.New("not implemented") } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 487fb5f1..e2eae0b4 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -82,6 +82,12 @@ func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { func (s *userRepoStubForGroupUpdate) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + panic("unexpected") +} +func (s *userRepoStubForGroupUpdate) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + panic("unexpected") +} func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { panic("unexpected") } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index a2644fcd..02a64c1c 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -566,6 +566,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet normalizedWhitelist = []string{} } settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist + alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) + if err != nil { + return err + } + wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) + if err != nil { + return err + } + settings.PaymentVisibleMethodAlipaySource = alipaySource + settings.PaymentVisibleMethodWxpaySource = wxpaySource updates := make(map[string]string) @@ -701,6 +711,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource + updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource + updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) + updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) + updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) // Balance low notification updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) @@ -730,6 +745,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet cchSigning: settings.EnableCCHSigning, expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -1137,7 +1157,12 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyMaxClaudeCodeVersion: "", // 分组隔离(默认不允许未分组 Key 调度) - SettingKeyAllowUngroupedKeyScheduling: "false", + SettingKeyAllowUngroupedKeyScheduling: "false", + SettingPaymentVisibleMethodAlipaySource: "", + SettingPaymentVisibleMethodWxpaySource: "", + SettingPaymentVisibleMethodAlipayEnabled: "false", + SettingPaymentVisibleMethodWxpayEnabled: "false", + openAIAdvancedSchedulerSettingKey: "false", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -1429,6 +1454,11 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0 } } + result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource]) + result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource]) + result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true" + result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" + result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" // Balance low notification result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" @@ -1458,6 +1488,28 @@ func isFalseSettingValue(value string) bool { } } +func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (string, error) { + source = strings.TrimSpace(source) + if source == "" { + if enabled { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source is required when the visible method is enabled", method), + ) + } + return "", nil + } + + normalized := NormalizeVisibleMethodSource(method, source) + if normalized == "" { + return "", infraerrors.BadRequest( + "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", + fmt.Sprintf("%s source must be one of the supported payment providers", method), + ) + } + return normalized, nil +} + func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index e62218b4..9dc0ca59 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -223,3 +223,34 @@ func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize]) require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions]) } + +func TestSettingService_UpdateSettings_PaymentVisibleMethodsAndAdvancedScheduler(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "alipay", + PaymentVisibleMethodWxpaySource: "easypay", + PaymentVisibleMethodAlipayEnabled: true, + PaymentVisibleMethodWxpayEnabled: false, + OpenAIAdvancedSchedulerEnabled: true, + }) + require.NoError(t, err) + require.Equal(t, VisibleMethodSourceOfficialAlipay, repo.updates[SettingPaymentVisibleMethodAlipaySource]) + require.Equal(t, VisibleMethodSourceEasyPayWechat, repo.updates[SettingPaymentVisibleMethodWxpaySource]) + require.Equal(t, "true", repo.updates[SettingPaymentVisibleMethodAlipayEnabled]) + require.Equal(t, "false", repo.updates[SettingPaymentVisibleMethodWxpayEnabled]) + require.Equal(t, "true", repo.updates[openAIAdvancedSchedulerSettingKey]) +} + +func TestSettingService_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + PaymentVisibleMethodAlipaySource: "not-a-provider", + }) + require.Error(t, err) + require.Equal(t, "INVALID_PAYMENT_VISIBLE_METHOD_SOURCE", infraerrors.Reason(err)) + require.Nil(t, repo.updates) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 72db4e31..9bd461f9 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -110,6 +110,15 @@ type SystemSettings struct { // Web Search Emulation WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Payment visible method routing + PaymentVisibleMethodAlipaySource string + PaymentVisibleMethodWxpaySource string + PaymentVisibleMethodAlipayEnabled bool + PaymentVisibleMethodWxpayEnabled bool + + // OpenAI account scheduling + OpenAIAdvancedSchedulerEnabled bool + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 89f0362e..e13fb95d 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -103,6 +103,12 @@ func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) er func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { return nil, nil } +func (m *mockUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (m *mockUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } -- GitLab From 4f6966d7b3979809da9238498bf403ced840d20b Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:05:42 +0800 Subject: [PATCH 077/265] frontend: route wechat oauth entry by public settings --- frontend/src/api/auth.ts | 55 ++++++ .../components/auth/WechatOAuthSection.vue | 62 ++++++- .../auth/__tests__/WechatOAuthSection.spec.ts | 170 ++++++++++++++++-- 3 files changed, 261 insertions(+), 26 deletions(-) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 6a2feb87..6c877d76 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -349,6 +349,61 @@ export async function getPublicSettings(): Promise { return data } +export type WeChatOAuthMode = 'open' | 'mp' +export type WeChatOAuthUnavailableReason = + | 'not_configured' + | 'external_browser_required' + | 'wechat_browser_required' + +export interface ResolvedWeChatOAuthStart { + mode: WeChatOAuthMode | null + openEnabled: boolean + mpEnabled: boolean + isWeChatBrowser: boolean + unavailableReason: WeChatOAuthUnavailableReason | null +} + +type WeChatOAuthPublicSettings = { + wechat_oauth_enabled?: boolean + wechat_oauth_open_enabled?: boolean + wechat_oauth_mp_enabled?: boolean +} + +export function resolveWeChatOAuthStart( + settings: WeChatOAuthPublicSettings | null | undefined, + userAgent?: string +): ResolvedWeChatOAuthStart { + const normalizedUserAgent = (userAgent + ?? (typeof navigator !== 'undefined' ? navigator.userAgent : '') + ?? '').trim() + const isWeChatBrowser = /MicroMessenger/i.test(normalizedUserAgent) + const legacyEnabled = settings?.wechat_oauth_enabled ?? false + const openEnabled = typeof settings?.wechat_oauth_open_enabled === 'boolean' + ? settings.wechat_oauth_open_enabled + : legacyEnabled + const mpEnabled = typeof settings?.wechat_oauth_mp_enabled === 'boolean' + ? settings.wechat_oauth_mp_enabled + : legacyEnabled + + if (isWeChatBrowser) { + if (mpEnabled) { + return { mode: 'mp', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null } + } + if (openEnabled) { + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'external_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } + } + + if (openEnabled) { + return { mode: 'open', openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: null } + } + if (mpEnabled) { + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'wechat_browser_required' } + } + return { mode: null, openEnabled, mpEnabled, isWeChatBrowser, unavailableReason: 'not_configured' } +} + /** * Send verification code to email * @param request - Email and optional Turnstile token diff --git a/frontend/src/components/auth/WechatOAuthSection.vue b/frontend/src/components/auth/WechatOAuthSection.vue index 94e20222..01bbd180 100644 --- a/frontend/src/components/auth/WechatOAuthSection.vue +++ b/frontend/src/components/auth/WechatOAuthSection.vue @@ -1,6 +1,6 @@ diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts new file mode 100644 index 00000000..5e6b0ae0 --- /dev/null +++ b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts @@ -0,0 +1,243 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' +import { defineComponent, h } from 'vue' + +import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue' + +const { getAuthIdentityMigrationReportSummary, listAuthIdentityMigrationReports, resolveAuthIdentityMigrationReport } = vi.hoisted(() => ({ + getAuthIdentityMigrationReportSummary: vi.fn(), + listAuthIdentityMigrationReports: vi.fn(), + resolveAuthIdentityMigrationReport: vi.fn(), +})) + +const { showError, showSuccess } = vi.hoisted(() => ({ + showError: vi.fn(), + showSuccess: vi.fn(), +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + users: { + getAuthIdentityMigrationReportSummary, + listAuthIdentityMigrationReports, + resolveAuthIdentityMigrationReport, + }, + }, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError, + showSuccess, + }), +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + locale: { value: 'en' }, + t: (key: string) => key, + }), + } +}) + +vi.mock('@/utils/format', () => ({ + formatDateTime: (value: string | null | undefined) => value ?? '', +})) + +const sampleReport = { + id: 1, + report_type: 'oidc_synthetic_email_requires_manual_recovery', + report_key: 'legacy@example.invalid', + details: { + user_id: 42, + legacy_email: 'legacy@example.invalid', + provider_key: 'https://issuer.example', + provider_subject: 'subject-123', + }, + created_at: '2026-04-20T01:02:03Z', + resolved_at: null, + resolved_by_user_id: null, + resolution_note: '', +} + +const summaryResponse = { + total: 2, + open_total: 1, + resolved_total: 1, + by_type: { + oidc_synthetic_email_requires_manual_recovery: 2, + }, +} + +const listResponse = { + items: [sampleReport], + total: 1, + page: 1, + page_size: 20, + pages: 1, +} + +const AppLayoutStub = defineComponent({ + setup(_, { slots }) { + return () => h('div', slots.default?.()) + }, +}) + +const TablePageLayoutStub = defineComponent({ + setup(_, { slots }) { + return () => h('div', [ + slots.actions?.(), + slots.filters?.(), + slots.table?.(), + slots.default?.(), + slots.pagination?.(), + ]) + }, +}) + +const DataTableStub = defineComponent({ + props: { + columns: { type: Array, default: () => [] }, + data: { type: Array, default: () => [] }, + loading: { type: Boolean, default: false }, + }, + setup(props, { slots }) { + return () => h('div', { 'data-test': 'data-table' }, [ + props.loading + ? h('div', 'loading') + : (props.data as Array>).map((row) => + h( + 'div', + { key: String(row.id ?? row.report_key) }, + (props.columns as Array<{ key: string }>).map((column) => { + const slot = slots[`cell-${column.key}`] + return h( + 'div', + { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` }, + slot + ? slot({ row, value: row[column.key] }) + : String(row[column.key] ?? '') + ) + }) + ) + ), + ]) + }, +}) + +const PaginationStub = defineComponent({ + props: { + total: { type: Number, required: true }, + page: { type: Number, required: true }, + pageSize: { type: Number, required: true }, + }, + emits: ['update:page', 'update:pageSize'], + setup(props, { emit }) { + return () => h('div', { 'data-test': 'pagination' }, [ + h('button', { + type: 'button', + 'data-test': 'next-page', + onClick: () => emit('update:page', props.page + 1), + }, 'next'), + h('button', { + type: 'button', + 'data-test': 'page-size-50', + onClick: () => emit('update:pageSize', 50), + }, '50'), + ]) + }, +}) + +describe('AuthIdentityMigrationReportsView', () => { + beforeEach(() => { + getAuthIdentityMigrationReportSummary.mockReset() + listAuthIdentityMigrationReports.mockReset() + resolveAuthIdentityMigrationReport.mockReset() + showError.mockReset() + showSuccess.mockReset() + + getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse) + listAuthIdentityMigrationReports.mockResolvedValue(listResponse) + resolveAuthIdentityMigrationReport.mockResolvedValue({ + ...sampleReport, + resolved_at: '2026-04-20T02:00:00Z', + resolved_by_user_id: 100, + resolution_note: 'resolved by admin', + }) + }) + + const mountView = () => + mount(AuthIdentityMigrationReportsView, { + global: { + stubs: { + AppLayout: AppLayoutStub, + TablePageLayout: TablePageLayoutStub, + DataTable: DataTableStub, + Pagination: PaginationStub, + Icon: true, + }, + }, + }) + + it('loads summary and first page of reports on mount', async () => { + const wrapper = mountView() + + await flushPromises() + + expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: '', + }) + expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2') + expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1') + expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1') + expect(wrapper.text()).toContain('legacy@example.invalid') + }) + + it('reloads list when the report type filter changes', async () => { + const wrapper = mountView() + + await flushPromises() + + listAuthIdentityMigrationReports.mockClear() + + await wrapper.get('[data-test="report-type-filter"]').setValue( + 'oidc_synthetic_email_requires_manual_recovery' + ) + await flushPromises() + + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: 'oidc_synthetic_email_requires_manual_recovery', + }) + }) + + it('submits resolve note for the selected report and refreshes data', async () => { + const wrapper = mountView() + + await flushPromises() + + getAuthIdentityMigrationReportSummary.mockClear() + listAuthIdentityMigrationReports.mockClear() + + await wrapper.get('[data-test="select-report-1"]').trigger('click') + await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin') + await wrapper.get('[data-test="resolve-submit"]').trigger('click') + await flushPromises() + + expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin') + expect(showSuccess).toHaveBeenCalled() + expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) + expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ + page: 1, + pageSize: 20, + reportType: '', + }) + }) +}) -- GitLab From 9204145746623d382f2fe5f9f2f44723042942dd Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 00:11:03 +0800 Subject: [PATCH 080/265] Close profile identity and avatar loop --- backend/internal/handler/auth_handler.go | 1 + .../internal/handler/auth_linuxdo_oauth.go | 2 +- .../handler/auth_oauth_pending_flow.go | 15 +- .../handler/auth_oauth_pending_flow_test.go | 145 ++++++++++++- backend/internal/handler/auth_oidc_oauth.go | 2 +- backend/internal/handler/auth_wechat_oauth.go | 2 +- backend/internal/handler/user_handler.go | 112 +++++++++- backend/internal/handler/user_handler_test.go | 79 +++++++ backend/internal/service/user_service.go | 75 +++++-- frontend/src/api/user.ts | 1 + .../user/profile/ProfileInfoCard.vue | 202 +++++++++++++++++- .../profile/__tests__/ProfileInfoCard.spec.ts | 172 +++++++++++++++ frontend/src/i18n/locales/en.ts | 14 ++ frontend/src/i18n/locales/zh.ts | 14 ++ 14 files changed, 801 insertions(+), 35 deletions(-) create mode 100644 frontend/src/components/user/profile/__tests__/ProfileInfoCard.spec.ts diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index e4697609..b984a436 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -296,6 +296,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { c.Request.Context(), h.entClient(), h.authService, + h.userService, pendingSession, decision, &user.ID, diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a0760a3b..175b1e1f 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -495,7 +495,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 1b3c1380..94186858 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -852,6 +852,7 @@ func applyPendingOAuthBinding( ctx context.Context, client *dbent.Client, authService *service.AuthService, + userService *service.UserService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, @@ -938,6 +939,12 @@ func applyPendingOAuthBinding( } } + if decision != nil && decision.AdoptAvatar && adoptedAvatarURL != "" && userService != nil { + if _, err := userService.SetAvatar(txCtx, targetUserID, adoptedAvatarURL); err != nil { + return err + } + } + return tx.Commit() } @@ -945,6 +952,7 @@ func applyPendingOAuthAdoption( ctx context.Context, client *dbent.Client, authService *service.AuthService, + userService *service.UserService, session *dbent.PendingAuthSession, decision *dbent.IdentityAdoptionDecision, overrideUserID *int64, @@ -953,6 +961,7 @@ func applyPendingOAuthAdoption( ctx, client, authService, + userService, session, decision, overrideUserID, @@ -1092,7 +1101,7 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { }) return } - if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID, true, true); err != nil { + if err := applyPendingOAuthBinding(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID, true, true); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -1188,7 +1197,7 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) response.ErrorFrom(c, err) return } - if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, session, decision, &user.ID, true, false); err != nil { + if err := applyPendingOAuthBinding(c.Request.Context(), client, h.authService, h.userService, session, decision, &user.ID, true, false); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_BIND_APPLY_FAILED", "failed to bind pending oauth identity").WithCause(err)) return } @@ -1278,7 +1287,7 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, session.TargetUserID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, session.TargetUserID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 8c468fdc..2521186e 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -152,6 +152,11 @@ func TestExchangePendingOAuthCompletionPreviewThenFinalizeAppliesAdoptionDecisio require.Equal(t, "Alice Example", identity.Metadata["display_name"]) require.Equal(t, "https://cdn.example/alice.png", identity.Metadata["avatar_url"]) + avatar := loadUserAvatarRecord(t, client, userEntity.ID) + require.NotNil(t, avatar) + require.Equal(t, "remote_url", avatar.StorageProvider) + require.Equal(t, "https://cdn.example/alice.png", avatar.URL) + decision, err := client.IdentityAdoptionDecision.Query(). Where(identityadoptiondecision.PendingAuthSessionIDEQ(session.ID)). Only(ctx) @@ -1242,6 +1247,18 @@ CREATE TABLE IF NOT EXISTS user_provider_default_grants ( UNIQUE(user_id, provider_type, grant_reason) )`) require.NoError(t, err) + _, err = db.Exec(` +CREATE TABLE IF NOT EXISTS user_avatars ( + user_id INTEGER PRIMARY KEY, + storage_provider TEXT NOT NULL, + storage_key TEXT NOT NULL DEFAULT '', + url TEXT NOT NULL, + content_type TEXT NOT NULL DEFAULT '', + byte_size INTEGER NOT NULL DEFAULT 0, + sha256 TEXT NOT NULL DEFAULT '', + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +)`) + require.NoError(t, err) drv := entsql.OpenDB(dialect.SQLite, db) client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) @@ -1492,6 +1509,35 @@ func decodeJSONBody(t *testing.T, recorder *httptest.ResponseRecorder) map[strin return payload } +type oauthPendingFlowAvatarRecord struct { + StorageProvider string + URL string +} + +func loadUserAvatarRecord(t *testing.T, client *dbent.Client, userID int64) *oauthPendingFlowAvatarRecord { + t.Helper() + + var rows entsql.Rows + err := client.Driver().Query( + context.Background(), + `SELECT storage_provider, url FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ) + require.NoError(t, err) + defer rows.Close() + + if !rows.Next() { + require.NoError(t, rows.Err()) + return nil + } + + var record oauthPendingFlowAvatarRecord + require.NoError(t, rows.Scan(&record.StorageProvider, &record.URL)) + require.NoError(t, rows.Err()) + return &record +} + func countProviderGrantRecords( t *testing.T, client *dbent.Client, @@ -1604,16 +1650,95 @@ func (r *oauthPendingFlowUserRepo) Delete(ctx context.Context, id int64) error { return r.client.User.DeleteOneID(id).Exec(ctx) } -func (r *oauthPendingFlowUserRepo) GetUserAvatar(context.Context, int64) (*service.UserAvatar, error) { - return nil, nil +func (r *oauthPendingFlowUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var rows entsql.Rows + if err := driver.Query( + ctx, + `SELECT storage_provider, storage_key, url, content_type, byte_size, sha256 FROM user_avatars WHERE user_id = ?`, + []any{userID}, + &rows, + ); err != nil { + return nil, err + } + defer rows.Close() + + if !rows.Next() { + return nil, rows.Err() + } + + var avatar service.UserAvatar + if err := rows.Scan( + &avatar.StorageProvider, + &avatar.StorageKey, + &avatar.URL, + &avatar.ContentType, + &avatar.ByteSize, + &avatar.SHA256, + ); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return &avatar, nil } -func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(context.Context, int64, service.UpsertUserAvatarInput) (*service.UserAvatar, error) { - panic("unexpected UpsertUserAvatar call") +func (r *oauthPendingFlowUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input service.UpsertUserAvatarInput) (*service.UserAvatar, error) { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + if err := driver.Exec( + ctx, + `INSERT INTO user_avatars (user_id, storage_provider, storage_key, url, content_type, byte_size, sha256, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) +ON CONFLICT(user_id) DO UPDATE SET + storage_provider = excluded.storage_provider, + storage_key = excluded.storage_key, + url = excluded.url, + content_type = excluded.content_type, + byte_size = excluded.byte_size, + sha256 = excluded.sha256, + updated_at = CURRENT_TIMESTAMP`, + []any{ + userID, + input.StorageProvider, + input.StorageKey, + input.URL, + input.ContentType, + input.ByteSize, + input.SHA256, + }, + &result, + ); err != nil { + return nil, err + } + + return &service.UserAvatar{ + StorageProvider: input.StorageProvider, + StorageKey: input.StorageKey, + URL: input.URL, + ContentType: input.ContentType, + ByteSize: input.ByteSize, + SHA256: input.SHA256, + }, nil } -func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(context.Context, int64) error { - return nil +func (r *oauthPendingFlowUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error { + driver := r.client.Driver() + if tx := dbent.TxFromContext(ctx); tx != nil { + driver = tx.Client().Driver() + } + + var result entsql.Result + return driver.Exec(ctx, `DELETE FROM user_avatars WHERE user_id = ?`, []any{userID}, &result) } func (r *oauthPendingFlowUserRepo) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { @@ -1636,6 +1761,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int panic("unexpected UpdateConcurrency call") } +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} + +func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} + func (r *oauthPendingFlowUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { count, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Count(ctx) return count > 0, err diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 70424ec5..0f9f1895 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -537,7 +537,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 4d25a763..b078b804 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -517,7 +517,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, session, decision, &user.ID); err != nil { + if err := applyPendingOAuthAdoption(c.Request.Context(), h.entClient(), h.authService, h.userService, session, decision, &user.ID); err != nil { response.ErrorFrom(c, infraerrors.InternalServer("PENDING_AUTH_ADOPTION_APPLY_FAILED", "failed to apply oauth profile adoption").WithCause(err)) return } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 9dcff828..b1ade5c0 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -2,6 +2,7 @@ package handler import ( "context" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -43,8 +44,24 @@ type UpdateProfileRequest struct { type userProfileResponse struct { dto.User - AvatarURL string `json:"avatar_url,omitempty"` - Identities service.UserIdentitySummarySet `json:"identities"` + AvatarURL string `json:"avatar_url,omitempty"` + AvatarSource *userProfileSourceContext `json:"avatar_source,omitempty"` + UsernameSource *userProfileSourceContext `json:"username_source,omitempty"` + DisplayNameSource *userProfileSourceContext `json:"display_name_source,omitempty"` + NicknameSource *userProfileSourceContext `json:"nickname_source,omitempty"` + ProfileSources map[string]*userProfileSourceContext `json:"profile_sources,omitempty"` + Identities service.UserIdentitySummarySet `json:"identities"` + AuthBindings map[string]service.UserIdentitySummary `json:"auth_bindings"` + IdentityBindings map[string]service.UserIdentitySummary `json:"identity_bindings"` + EmailBound bool `json:"email_bound"` + LinuxDoBound bool `json:"linuxdo_bound"` + OIDCBound bool `json:"oidc_bound"` + WeChatBound bool `json:"wechat_bound"` +} + +type userProfileSourceContext struct { + Provider string `json:"provider,omitempty"` + Source string `json:"source,omitempty"` } // GetProfile handles getting user profile @@ -335,9 +352,94 @@ func userProfileResponseFromService(user *service.User, identities service.UserI if base == nil { return userProfileResponse{} } + bindings := userProfileBindingMap(identities) + profileSources, avatarSource, usernameSource := inferUserProfileSources(user, identities) return userProfileResponse{ - User: *base, - AvatarURL: user.AvatarURL, - Identities: identities, + User: *base, + AvatarURL: user.AvatarURL, + AvatarSource: avatarSource, + UsernameSource: usernameSource, + DisplayNameSource: usernameSource, + NicknameSource: usernameSource, + ProfileSources: profileSources, + Identities: identities, + AuthBindings: bindings, + IdentityBindings: bindings, + EmailBound: identities.Email.Bound, + LinuxDoBound: identities.LinuxDo.Bound, + OIDCBound: identities.OIDC.Bound, + WeChatBound: identities.WeChat.Bound, + } +} + +func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary { + return map[string]service.UserIdentitySummary{ + "email": identities.Email, + "linuxdo": identities.LinuxDo, + "oidc": identities.OIDC, + "wechat": identities.WeChat, + } +} + +func inferUserProfileSources(user *service.User, identities service.UserIdentitySummarySet) ( + map[string]*userProfileSourceContext, + *userProfileSourceContext, + *userProfileSourceContext, +) { + if user == nil { + return nil, nil, nil + } + + thirdParty := thirdPartyIdentityProviders(identities) + var avatarSource *userProfileSourceContext + if strings.TrimSpace(user.AvatarURL) != "" && len(thirdParty) == 1 { + avatarSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + usernameValue := strings.TrimSpace(user.Username) + var usernameSource *userProfileSourceContext + for _, summary := range thirdParty { + if usernameValue != "" && usernameValue == strings.TrimSpace(summary.DisplayName) { + usernameSource = buildUserProfileSourceContext(summary.Provider) + break + } + } + if usernameSource == nil && usernameValue != "" && len(thirdParty) == 1 { + usernameSource = buildUserProfileSourceContext(thirdParty[0].Provider) + } + + profileSources := map[string]*userProfileSourceContext{} + if avatarSource != nil { + profileSources["avatar"] = avatarSource + } + if usernameSource != nil { + profileSources["username"] = usernameSource + profileSources["display_name"] = usernameSource + profileSources["nickname"] = usernameSource + } + if len(profileSources) == 0 { + return nil, avatarSource, usernameSource + } + return profileSources, avatarSource, usernameSource +} + +func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { + out := make([]service.UserIdentitySummary, 0, 3) + for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { + if summary.Bound { + out = append(out, summary) + } + } + return out +} + +func buildUserProfileSourceContext(provider string) *userProfileSourceContext { + provider = strings.TrimSpace(provider) + if provider == "" { + return nil + } + return &userProfileSourceContext{ + Provider: provider, + Source: provider, } } diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index b71846c1..1216f9c4 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -92,6 +92,12 @@ func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int6 func (s *userHandlerRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (s *userHandlerRepoStub) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { + return map[int64]*time.Time{}, nil +} +func (s *userHandlerRepoStub) GetLatestUsedAtByUserID(context.Context, int64) (*time.Time, error) { + return nil, nil +} func (s *userHandlerRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { return nil } @@ -230,6 +236,79 @@ func TestUserHandlerGetProfileReturnsIdentitySummaries(t *testing.T) { require.Contains(t, resp.Data.Identities.WeChat.BindStartPath, "/api/v1/auth/oauth/wechat/start") } +func TestUserHandlerGetProfileReturnsLegacyCompatibilityFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + verifiedAt := time.Date(2026, 4, 20, 8, 30, 0, 0, time.UTC) + repo := &userHandlerRepoStub{ + user: &service.User{ + ID: 21, + Email: "legacy-profile@example.com", + Username: "linuxdo-handle", + Role: service.RoleUser, + Status: service.StatusActive, + AvatarURL: "https://cdn.example.com/linuxdo.png", + AvatarSource: "remote_url", + }, + identities: []service.UserAuthIdentityRecord{ + { + ProviderType: "linuxdo", + ProviderKey: "linuxdo", + ProviderSubject: "linuxdo-subject-21", + VerifiedAt: &verifiedAt, + Metadata: map[string]any{ + "username": "linuxdo-handle", + }, + }, + }, + } + handler := NewUserHandler(service.NewUserService(repo, nil, nil, nil), nil, nil) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/user/profile", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 21}) + + handler.GetProfile(c) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data map[string]any `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Equal(t, true, resp.Data["email_bound"]) + require.Equal(t, true, resp.Data["linuxdo_bound"]) + require.Equal(t, false, resp.Data["oidc_bound"]) + require.Equal(t, false, resp.Data["wechat_bound"]) + require.Equal(t, "https://cdn.example.com/linuxdo.png", resp.Data["avatar_url"]) + + authBindings, ok := resp.Data["auth_bindings"].(map[string]any) + require.True(t, ok) + linuxdoBinding, ok := authBindings["linuxdo"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, linuxdoBinding["bound"]) + require.Equal(t, "linuxdo", linuxdoBinding["provider"]) + + identityBindings, ok := resp.Data["identity_bindings"].(map[string]any) + require.True(t, ok) + emailBinding, ok := identityBindings["email"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, emailBinding["bound"]) + + avatarSource, ok := resp.Data["avatar_source"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", avatarSource["provider"]) + + profileSources, ok := resp.Data["profile_sources"].(map[string]any) + require.True(t, ok) + usernameSource, ok := profileSources["username"].(map[string]any) + require.True(t, ok) + require.Equal(t, "linuxdo", usernameSource["provider"]) +} + func TestUserHandlerStartIdentityBindingReturnsAuthorizeURL(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c52f91bb..c106a3f5 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -65,6 +65,8 @@ type UserRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) + GetLatestUsedAtByUserIDs(ctx context.Context, userIDs []int64) (map[int64]*time.Time, error) + GetLatestUsedAtByUserID(ctx context.Context, userID int64) (*time.Time, error) UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error @@ -159,6 +161,33 @@ type userAuthIdentityReader interface { ListUserAuthIdentities(ctx context.Context, userID int64) ([]UserAuthIdentityRecord, error) } +type emailAuthIdentitySynchronizer interface { + EnsureEmailAuthIdentity(ctx context.Context, userID int64, email string) error + ReplaceEmailAuthIdentity(ctx context.Context, userID int64, oldEmail, newEmail string) error +} + +func ensureEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, email string) error { + syncer, ok := repo.(emailAuthIdentitySynchronizer) + if !ok { + return nil + } + return syncer.EnsureEmailAuthIdentity(ctx, userID, email) +} + +func replaceEmailAuthIdentitySync(ctx context.Context, repo UserRepository, userID int64, oldEmail, newEmail string) error { + oldNormalized := strings.ToLower(strings.TrimSpace(oldEmail)) + newNormalized := strings.ToLower(strings.TrimSpace(newEmail)) + if oldNormalized == newNormalized { + return nil + } + + syncer, ok := repo.(emailAuthIdentitySynchronizer) + if !ok { + return nil + } + return syncer.ReplaceEmailAuthIdentity(ctx, userID, oldEmail, newEmail) +} + // ChangePasswordRequest 修改密码请求 type ChangePasswordRequest struct { CurrentPassword string `json:"current_password"` @@ -252,6 +281,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return nil, fmt.Errorf("get user: %w", err) } oldConcurrency := user.Concurrency + oldEmail := user.Email // 更新字段 if req.Email != nil { @@ -271,24 +301,11 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat } if req.AvatarURL != nil { - avatarValue := strings.TrimSpace(*req.AvatarURL) - switch { - case avatarValue == "": - if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { - return nil, fmt.Errorf("delete avatar: %w", err) - } - applyUserAvatar(user, nil) - default: - avatarInput, err := normalizeUserAvatarInput(avatarValue) - if err != nil { - return nil, err - } - avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) - if err != nil { - return nil, fmt.Errorf("upsert avatar: %w", err) - } - applyUserAvatar(user, avatar) + avatar, err := s.SetAvatar(ctx, userID, *req.AvatarURL) + if err != nil { + return nil, err } + applyUserAvatar(user, avatar) } if req.Concurrency != nil { @@ -309,6 +326,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if err := replaceEmailAuthIdentitySync(ctx, s.userRepo, user.ID, oldEmail, user.Email); err != nil { + return nil, fmt.Errorf("sync email auth identity: %w", err) + } if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } @@ -316,6 +336,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat return user, nil } +func (s *UserService) SetAvatar(ctx context.Context, userID int64, raw string) (*UserAvatar, error) { + avatarValue := strings.TrimSpace(raw) + if avatarValue == "" { + if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil { + return nil, fmt.Errorf("delete avatar: %w", err) + } + return nil, nil + } + + avatarInput, err := normalizeUserAvatarInput(avatarValue) + if err != nil { + return nil, err + } + + avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput) + if err != nil { + return nil, fmt.Errorf("upsert avatar: %w", err) + } + return avatar, nil +} + func applyUserAvatar(user *User, avatar *UserAvatar) { if user == nil { return diff --git a/frontend/src/api/user.ts b/frontend/src/api/user.ts index 1f6e4cd9..c7b1e503 100644 --- a/frontend/src/api/user.ts +++ b/frontend/src/api/user.ts @@ -23,6 +23,7 @@ export async function getProfile(): Promise { */ export async function updateProfile(profile: { username?: string + avatar_url?: string | null balance_notify_enabled?: boolean balance_notify_threshold?: number | null balance_notify_extra_emails?: NotifyEmailEntry[] diff --git a/frontend/src/components/user/profile/ProfileInfoCard.vue b/frontend/src/components/user/profile/ProfileInfoCard.vue index e82ae229..d6273431 100644 --- a/frontend/src/components/user/profile/ProfileInfoCard.vue +++ b/frontend/src/components/user/profile/ProfileInfoCard.vue @@ -61,6 +61,71 @@
+
+
+
+

+ {{ t('profile.avatar.title') }} +

+

+ {{ t('profile.avatar.description') }} +

+
+ +
+ +
+ + -
- - -
- -
-

- {{ copy.remediationTitle }} -

-

- {{ copy.remediationSubtitle }} -

- -
-
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- - -
-
-
- -
- - - - diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index 93cfdbbe..39c9b377 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -712,8 +712,8 @@ const allColumns = computed(() => [ { key: 'usage', label: t('admin.users.columns.usage'), sortable: false }, { key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true }, { key: 'status', label: t('admin.users.columns.status'), sortable: true }, - { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true }, { key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true }, + { key: 'last_used_at', label: t('admin.users.columns.lastUsed'), sortable: true }, { key: 'created_at', label: t('admin.users.columns.created'), sortable: true }, { key: 'actions', label: t('admin.users.columns.actions'), sortable: false } ]) diff --git a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts b/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts deleted file mode 100644 index 20f57fa1..00000000 --- a/frontend/src/views/admin/__tests__/AuthIdentityMigrationReportsView.spec.ts +++ /dev/null @@ -1,303 +0,0 @@ -import { beforeEach, describe, expect, it, vi } from 'vitest' -import { flushPromises, mount } from '@vue/test-utils' -import { defineComponent, h } from 'vue' - -import AuthIdentityMigrationReportsView from '../AuthIdentityMigrationReportsView.vue' - -const { - bindUserAuthIdentity, - getAuthIdentityMigrationReportSummary, - listAuthIdentityMigrationReports, - resolveAuthIdentityMigrationReport, -} = vi.hoisted(() => ({ - bindUserAuthIdentity: vi.fn(), - getAuthIdentityMigrationReportSummary: vi.fn(), - listAuthIdentityMigrationReports: vi.fn(), - resolveAuthIdentityMigrationReport: vi.fn(), -})) - -const { showError, showSuccess } = vi.hoisted(() => ({ - showError: vi.fn(), - showSuccess: vi.fn(), -})) - -vi.mock('@/api/admin', () => ({ - adminAPI: { - users: { - bindUserAuthIdentity, - getAuthIdentityMigrationReportSummary, - listAuthIdentityMigrationReports, - resolveAuthIdentityMigrationReport, - }, - }, -})) - -vi.mock('@/stores/app', () => ({ - useAppStore: () => ({ - showError, - showSuccess, - }), -})) - -vi.mock('vue-i18n', async () => { - const actual = await vi.importActual('vue-i18n') - return { - ...actual, - useI18n: () => ({ - locale: { value: 'en' }, - t: (key: string) => key, - }), - } -}) - -vi.mock('@/utils/format', () => ({ - formatDateTime: (value: string | null | undefined) => value ?? '', -})) - -const sampleReport = { - id: 1, - report_type: 'oidc_synthetic_email_requires_manual_recovery', - report_key: 'legacy@example.invalid', - details: { - user_id: 42, - legacy_email: 'legacy@example.invalid', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - }, - created_at: '2026-04-20T01:02:03Z', - resolved_at: null, - resolved_by_user_id: null, - resolution_note: '', -} - -const summaryResponse = { - total: 2, - open_total: 1, - resolved_total: 1, - by_type: { - oidc_synthetic_email_requires_manual_recovery: 2, - }, -} - -const listResponse = { - items: [sampleReport], - total: 1, - page: 1, - page_size: 20, - pages: 1, -} - -const AppLayoutStub = defineComponent({ - setup(_, { slots }) { - return () => h('div', slots.default?.()) - }, -}) - -const TablePageLayoutStub = defineComponent({ - setup(_, { slots }) { - return () => h('div', [ - slots.actions?.(), - slots.filters?.(), - slots.table?.(), - slots.default?.(), - slots.pagination?.(), - ]) - }, -}) - -const DataTableStub = defineComponent({ - props: { - columns: { type: Array, default: () => [] }, - data: { type: Array, default: () => [] }, - loading: { type: Boolean, default: false }, - }, - setup(props, { slots }) { - return () => h('div', { 'data-test': 'data-table' }, [ - props.loading - ? h('div', 'loading') - : (props.data as Array>).map((row) => - h( - 'div', - { key: String(row.id ?? row.report_key) }, - (props.columns as Array<{ key: string }>).map((column) => { - const slot = slots[`cell-${column.key}`] - return h( - 'div', - { key: column.key, [`data-test-cell`]: `${String(row.id)}-${column.key}` }, - slot - ? slot({ row, value: row[column.key] }) - : String(row[column.key] ?? '') - ) - }) - ) - ), - ]) - }, -}) - -const PaginationStub = defineComponent({ - props: { - total: { type: Number, required: true }, - page: { type: Number, required: true }, - pageSize: { type: Number, required: true }, - }, - emits: ['update:page', 'update:pageSize'], - setup(props, { emit }) { - return () => h('div', { 'data-test': 'pagination' }, [ - h('button', { - type: 'button', - 'data-test': 'next-page', - onClick: () => emit('update:page', props.page + 1), - }, 'next'), - h('button', { - type: 'button', - 'data-test': 'page-size-50', - onClick: () => emit('update:pageSize', 50), - }, '50'), - ]) - }, -}) - -describe('AuthIdentityMigrationReportsView', () => { - beforeEach(() => { - getAuthIdentityMigrationReportSummary.mockReset() - listAuthIdentityMigrationReports.mockReset() - resolveAuthIdentityMigrationReport.mockReset() - bindUserAuthIdentity.mockReset() - showError.mockReset() - showSuccess.mockReset() - - getAuthIdentityMigrationReportSummary.mockResolvedValue(summaryResponse) - listAuthIdentityMigrationReports.mockResolvedValue(listResponse) - resolveAuthIdentityMigrationReport.mockResolvedValue({ - ...sampleReport, - resolved_at: '2026-04-20T02:00:00Z', - resolved_by_user_id: 100, - resolution_note: 'resolved by admin', - }) - bindUserAuthIdentity.mockResolvedValue({ - identity_id: 77, - provider_type: 'oidc', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - }) - }) - - const mountView = () => - mount(AuthIdentityMigrationReportsView, { - global: { - stubs: { - AppLayout: AppLayoutStub, - TablePageLayout: TablePageLayoutStub, - DataTable: DataTableStub, - Pagination: PaginationStub, - Icon: true, - }, - }, - }) - - it('loads summary and first page of reports on mount', async () => { - const wrapper = mountView() - - await flushPromises() - - expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: '', - }) - expect(wrapper.get('[data-test="summary-total"]').text()).toContain('2') - expect(wrapper.get('[data-test="summary-open"]').text()).toContain('1') - expect(wrapper.get('[data-test="summary-resolved"]').text()).toContain('1') - expect(wrapper.text()).toContain('legacy@example.invalid') - }) - - it('reloads list when the report type filter changes', async () => { - const wrapper = mountView() - - await flushPromises() - - listAuthIdentityMigrationReports.mockClear() - - await wrapper.get('[data-test="report-type-filter"]').setValue( - 'oidc_synthetic_email_requires_manual_recovery' - ) - await flushPromises() - - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: 'oidc_synthetic_email_requires_manual_recovery', - }) - }) - - it('submits resolve note for the selected report and refreshes data', async () => { - const wrapper = mountView() - - await flushPromises() - - getAuthIdentityMigrationReportSummary.mockClear() - listAuthIdentityMigrationReports.mockClear() - - await wrapper.get('[data-test="select-report-1"]').trigger('click') - await wrapper.get('[data-test="resolution-note"]').setValue('resolved by admin') - await wrapper.get('[data-test="resolve-submit"]').trigger('click') - await flushPromises() - - expect(resolveAuthIdentityMigrationReport).toHaveBeenCalledWith(1, 'resolved by admin') - expect(showSuccess).toHaveBeenCalled() - expect(getAuthIdentityMigrationReportSummary).toHaveBeenCalledTimes(1) - expect(listAuthIdentityMigrationReports).toHaveBeenCalledWith({ - page: 1, - pageSize: 20, - reportType: '', - }) - }) - - it('pre-fills and submits remediation binding for the selected report', async () => { - const wrapper = mountView() - - await flushPromises() - await wrapper.get('[data-test="select-report-1"]').trigger('click') - await flushPromises() - - expect((wrapper.get('[data-test="remediation-user-id"]').element as HTMLInputElement).value).toBe('42') - expect((wrapper.get('[data-test="remediation-provider-type"]').element as HTMLInputElement).value).toBe('oidc') - expect((wrapper.get('[data-test="remediation-provider-key"]').element as HTMLInputElement).value).toBe( - 'https://issuer.example' - ) - expect((wrapper.get('[data-test="remediation-provider-subject"]').element as HTMLInputElement).value).toBe( - 'subject-123' - ) - - await wrapper.get('[data-test="remediation-submit"]').trigger('click') - await flushPromises() - - expect(bindUserAuthIdentity).toHaveBeenCalledWith(42, { - provider_type: 'oidc', - provider_key: 'https://issuer.example', - provider_subject: 'subject-123', - issuer: undefined, - metadata: {}, - }) - expect(showSuccess).toHaveBeenCalled() - }) - - it('keeps report type filter options available from list data when summary fails', async () => { - getAuthIdentityMigrationReportSummary.mockRejectedValueOnce(new Error('summary failed')) - listAuthIdentityMigrationReports.mockResolvedValueOnce(listResponse) - - const wrapper = mountView() - - await flushPromises() - - const options = wrapper - .get('[data-test="report-type-filter"]') - .findAll('option') - .map((node) => node.element.value) - - expect(showError).toHaveBeenCalled() - expect(options).toContain('oidc_synthetic_email_requires_manual_recovery') - }) -}) diff --git a/frontend/src/views/admin/__tests__/UsersView.spec.ts b/frontend/src/views/admin/__tests__/UsersView.spec.ts index d9076777..532d89f1 100644 --- a/frontend/src/views/admin/__tests__/UsersView.spec.ts +++ b/frontend/src/views/admin/__tests__/UsersView.spec.ts @@ -70,7 +70,6 @@ const createAdminUser = (): AdminUser => ({ created_at: '2026-04-17T00:00:00Z', updated_at: '2026-04-17T00:00:00Z', notes: '', - last_login_at: '2026-04-16T01:00:00Z', last_active_at: '2026-04-16T02:00:00Z', last_used_at: '2026-04-17T02:00:00Z', current_concurrency: 0 @@ -113,7 +112,7 @@ describe('admin UsersView', () => { getBatchUserAttributes.mockResolvedValue({ values: {} }) }) - it('shows active and used activity columns, hides last_login_at, and requests last_used_at sort', async () => { + it('shows active, used, and created activity columns in order and requests last_used_at sort', async () => { const wrapper = mount(UsersView, { global: { stubs: { @@ -145,9 +144,9 @@ describe('admin UsersView', () => { await flushPromises() const columns = wrapper.get('[data-test="columns"]').text() - expect(columns).toContain('last_used_at') - expect(columns).toContain('last_active_at') - expect(columns).not.toContain('last_login_at') + const visibleColumns = columns.split(',') + expect(visibleColumns.slice(-4, -1)).toEqual(['last_active_at', 'last_used_at', 'created_at']) + expect(visibleColumns).not.toContain('last_login_at') await wrapper.get('[data-test="sort-last-used"]').trigger('click') await flushPromises() -- GitLab From ee3f158f4e11dff97e0db0dccecc6885dcfd1b96 Mon Sep 17 00:00:00 2001 From: IanShaw027 Date: Tue, 21 Apr 2026 17:35:12 +0800 Subject: [PATCH 156/265] fix(settings): restore wechat and payment config persistence --- .../internal/handler/admin/setting_handler.go | 99 + backend/internal/handler/auth_wechat_oauth.go | 42 +- .../handler/auth_wechat_oauth_test.go | 100 +- backend/internal/handler/dto/settings.go | 8 + .../handler/setting_handler_public_test.go | 21 +- backend/internal/service/domain_constants.go | 9 + .../service/payment_config_service.go | 9 + .../service/payment_config_service_test.go | 47 +- backend/internal/service/payment_order.go | 16 +- .../service/payment_order_jsapi_test.go | 16 +- .../service/payment_order_result_test.go | 36 +- backend/internal/service/setting_service.go | 168 +- .../service/setting_service_public_test.go | 21 +- .../setting_service_wechat_config_test.go | 77 + backend/internal/service/settings_view.go | 20 + .../__tests__/settings.wechatConnect.spec.ts | 21 + frontend/src/api/admin/settings.ts | 1036 +- frontend/src/views/admin/SettingsView.vue | 8731 ++++++++++------- .../admin/__tests__/SettingsView.spec.ts | 620 +- 19 files changed, 6892 insertions(+), 4205 deletions(-) create mode 100644 backend/internal/service/setting_service_wechat_config_test.go create mode 100644 frontend/src/api/__tests__/settings.wechatConnect.spec.ts diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index f0e91f3a..9bc20771 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -122,6 +122,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: settings.WeChatConnectEnabled, + WeChatConnectAppID: settings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured, + WeChatConnectMode: settings.WeChatConnectMode, + WeChatConnectScopes: settings.WeChatConnectScopes, + WeChatConnectRedirectURL: settings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: settings.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: settings.OIDCConnectEnabled, OIDCConnectProviderName: settings.OIDCConnectProviderName, OIDCConnectClientID: settings.OIDCConnectClientID, @@ -246,6 +253,15 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecret string `json:"wechat_connect_app_secret"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` @@ -509,6 +525,54 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + if req.WeChatConnectEnabled { + req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID) + req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret) + req.WeChatConnectMode = strings.ToLower(strings.TrimSpace(req.WeChatConnectMode)) + req.WeChatConnectScopes = strings.TrimSpace(req.WeChatConnectScopes) + req.WeChatConnectRedirectURL = strings.TrimSpace(req.WeChatConnectRedirectURL) + req.WeChatConnectFrontendRedirectURL = strings.TrimSpace(req.WeChatConnectFrontendRedirectURL) + + if req.WeChatConnectAppID == "" { + response.BadRequest(c, "WeChat App ID is required when enabled") + return + } + if req.WeChatConnectAppSecret == "" { + if previousSettings.WeChatConnectAppSecret == "" { + response.BadRequest(c, "WeChat App Secret is required when enabled") + return + } + req.WeChatConnectAppSecret = previousSettings.WeChatConnectAppSecret + } + if req.WeChatConnectMode == "" { + req.WeChatConnectMode = "open" + } + switch req.WeChatConnectMode { + case "open", "mp": + default: + response.BadRequest(c, "WeChat mode must be open or mp") + return + } + if req.WeChatConnectScopes == "" { + req.WeChatConnectScopes = service.DefaultWeChatConnectScopesForMode(req.WeChatConnectMode) + } + if req.WeChatConnectRedirectURL == "" { + response.BadRequest(c, "WeChat Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.WeChatConnectRedirectURL); err != nil { + response.BadRequest(c, "WeChat Redirect URL must be an absolute http(s) URL") + return + } + if req.WeChatConnectFrontendRedirectURL == "" { + req.WeChatConnectFrontendRedirectURL = "/auth/wechat/callback" + } + if err := config.ValidateFrontendRedirectURL(req.WeChatConnectFrontendRedirectURL); err != nil { + response.BadRequest(c, "WeChat Frontend Redirect URL is invalid") + return + } + } + // Generic OIDC 参数验证 if req.OIDCConnectEnabled { req.OIDCConnectProviderName = strings.TrimSpace(req.OIDCConnectProviderName) @@ -857,6 +921,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: req.WeChatConnectEnabled, + WeChatConnectAppID: req.WeChatConnectAppID, + WeChatConnectAppSecret: req.WeChatConnectAppSecret, + WeChatConnectMode: req.WeChatConnectMode, + WeChatConnectScopes: req.WeChatConnectScopes, + WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: req.OIDCConnectEnabled, OIDCConnectProviderName: req.OIDCConnectProviderName, OIDCConnectClientID: req.OIDCConnectClientID, @@ -1136,6 +1207,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled, + WeChatConnectAppID: updatedSettings.WeChatConnectAppID, + WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured, + WeChatConnectMode: updatedSettings.WeChatConnectMode, + WeChatConnectScopes: updatedSettings.WeChatConnectScopes, + WeChatConnectRedirectURL: updatedSettings.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: updatedSettings.WeChatConnectFrontendRedirectURL, OIDCConnectEnabled: updatedSettings.OIDCConnectEnabled, OIDCConnectProviderName: updatedSettings.OIDCConnectProviderName, OIDCConnectClientID: updatedSettings.OIDCConnectClientID, @@ -1329,6 +1407,27 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.WeChatConnectEnabled != after.WeChatConnectEnabled { + changed = append(changed, "wechat_connect_enabled") + } + if before.WeChatConnectAppID != after.WeChatConnectAppID { + changed = append(changed, "wechat_connect_app_id") + } + if req.WeChatConnectAppSecret != "" { + changed = append(changed, "wechat_connect_app_secret") + } + if before.WeChatConnectMode != after.WeChatConnectMode { + changed = append(changed, "wechat_connect_mode") + } + if before.WeChatConnectScopes != after.WeChatConnectScopes { + changed = append(changed, "wechat_connect_scopes") + } + if before.WeChatConnectRedirectURL != after.WeChatConnectRedirectURL { + changed = append(changed, "wechat_connect_redirect_url") + } + if before.WeChatConnectFrontendRedirectURL != after.WeChatConnectFrontendRedirectURL { + changed = append(changed, "wechat_connect_frontend_redirect_url") + } if before.OIDCConnectEnabled != after.OIDCConnectEnabled { changed = append(changed, "oidc_connect_enabled") } diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 5e697fb5..734fb2ef 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/url" - "os" "strconv" "strings" "time" @@ -149,7 +148,7 @@ func (h *AuthHandler) WeChatOAuthStart(c *gin.Context) { // WeChatOAuthCallback exchanges the code with WeChat, resolves openid/unionid, // and stores the result in the unified pending-auth flow. func (h *AuthHandler) WeChatOAuthCallback(c *gin.Context) { - frontendCallback := wechatOAuthFrontendCallback() + frontendCallback := h.wechatOAuthFrontendCallback(c.Request.Context()) if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) @@ -859,6 +858,10 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, return wechatOAuthConfig{}, err } + if h == nil || h.settingSvc == nil { + return wechatOAuthConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "wechat oauth settings service not ready") + } + apiBaseURL := "" if h != nil && h.settingSvc != nil { settings, err := h.settingSvc.GetAllSettings(ctx) @@ -867,27 +870,28 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, } } + effective, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err != nil { + return wechatOAuthConfig{}, err + } + if effective.Mode != mode { + return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + cfg := wechatOAuthConfig{ mode: mode, - redirectURI: resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback"), - frontendCallback: wechatOAuthFrontendCallback(), + appID: strings.TrimSpace(effective.AppID), + appSecret: strings.TrimSpace(effective.AppSecret), + redirectURI: firstNonEmpty(strings.TrimSpace(effective.RedirectURL), resolveWeChatOAuthAbsoluteURL(apiBaseURL, c, "/api/v1/auth/oauth/wechat/callback")), + frontendCallback: firstNonEmpty(strings.TrimSpace(effective.FrontendRedirectURL), wechatOAuthDefaultFrontendCB), + scope: firstNonEmpty(strings.TrimSpace(effective.Scopes), service.DefaultWeChatConnectScopesForMode(mode)), } switch mode { case "mp": - cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) - cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) cfg.authorizeURL = "https://open.weixin.qq.com/connect/oauth2/authorize" - cfg.scope = "snsapi_userinfo" default: - cfg.appID = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) - cfg.appSecret = strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) cfg.authorizeURL = "https://open.weixin.qq.com/connect/qrconnect" - cfg.scope = "snsapi_login" - } - - if cfg.appID == "" || cfg.appSecret == "" { - return wechatOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") } if strings.TrimSpace(cfg.redirectURI) == "" { return wechatOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") @@ -896,8 +900,14 @@ func (h *AuthHandler) getWeChatOAuthConfig(ctx context.Context, rawMode string, return cfg, nil } -func wechatOAuthFrontendCallback() string { - return firstNonEmpty(strings.TrimSpace(os.Getenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL")), wechatOAuthDefaultFrontendCB) +func (h *AuthHandler) wechatOAuthFrontendCallback(ctx context.Context) string { + if h != nil && h.settingSvc != nil { + cfg, err := h.settingSvc.GetWeChatConnectOAuthConfig(ctx) + if err == nil && strings.TrimSpace(cfg.FrontendRedirectURL) != "" { + return strings.TrimSpace(cfg.FrontendRedirectURL) + } + } + return wechatOAuthDefaultFrontendCB } func resolveWeChatOAuthMode(rawMode string, c *gin.Context) (string, error) { diff --git a/backend/internal/handler/auth_wechat_oauth_test.go b/backend/internal/handler/auth_wechat_oauth_test.go index cd34f52f..b0fee617 100644 --- a/backend/internal/handler/auth_wechat_oauth_test.go +++ b/backend/internal/handler/auth_wechat_oauth_test.go @@ -33,16 +33,22 @@ import ( ) func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - gin.SetMode(gin.TestMode) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-open-app", + service.SettingKeyWeChatConnectAppSecret: "wx-open-secret", + service.SettingKeyWeChatConnectMode: "open", + service.SettingKeyWeChatConnectScopes: "snsapi_login", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) + defer client.Close() recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/wechat/start?mode=open&redirect=/billing", nil) c.Request.Host = "api.example.com" - handler := &AuthHandler{} handler.WeChatOAuthStart(c) require.Equal(t, http.StatusFound, recorder.Code) @@ -60,10 +66,6 @@ func TestWeChatOAuthStartRedirectsAndSetsPendingCookies(t *testing.T) { } func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -124,10 +126,6 @@ func TestWeChatOAuthCallbackCreatesPendingSessionForUnifiedFlow(t *testing.T) { } func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "https://app.example.com/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -151,7 +149,7 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "https://app.example.com/auth/wechat/callback")) defer client.Close() recorder := httptest.NewRecorder() @@ -177,9 +175,6 @@ func TestWeChatOAuthCallbackRejectsMissingUnionID(t *testing.T) { } func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wx-mp-secret") - originalAccessTokenURL := wechatOAuthAccessTokenURL t.Cleanup(func() { wechatOAuthAccessTokenURL = originalAccessTokenURL @@ -196,7 +191,7 @@ func TestWeChatPaymentOAuthCallbackRedirectsWithOpaqueResumeToken(t *testing.T) defer upstream.Close() wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings("mp", "wx-mp-app", "wx-mp-secret", "/auth/wechat/callback")) defer client.Close() handler.cfg.Totp.EncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" @@ -240,7 +235,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test testCases := []struct { name string mode string - appIDEnv string appID string appSecret string openID string @@ -248,7 +242,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test { name: "open", mode: "open", - appIDEnv: "WECHAT_OAUTH_OPEN_APP_ID", appID: "wx-open-app", appSecret: "wx-open-secret", openID: "openid-open-123", @@ -256,7 +249,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test { name: "mp", mode: "mp", - appIDEnv: "WECHAT_OAUTH_MP_APP_ID", appID: "wx-mp-app", appSecret: "wx-mp-secret", openID: "openid-mp-123", @@ -265,15 +257,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - t.Setenv(tc.appIDEnv, tc.appID) - switch tc.mode { - case "open": - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", tc.appSecret) - case "mp": - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", tc.appSecret) - } - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -297,7 +280,7 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test wechatOAuthAccessTokenURL = upstream.URL + "/sns/oauth2/access_token" wechatOAuthUserInfoURL = upstream.URL + "/sns/userinfo" - handler, client := newWeChatOAuthTestHandler(t, false) + handler, client := newWeChatOAuthTestHandlerWithSettings(t, false, wechatOAuthTestSettings(tc.mode, tc.appID, tc.appSecret, "/auth/wechat/callback")) defer client.Close() currentUser, err := client.User.Create(). @@ -354,10 +337,6 @@ func TestWeChatOAuthCallbackBindUsesUnionCanonicalIdentityAcrossChannels(t *test } func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -436,10 +415,6 @@ func TestWeChatOAuthCallbackBindRejectsCanonicalOwnershipConflict(t *testing.T) } func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -529,10 +504,6 @@ func TestWeChatOAuthCallbackBindRejectsChannelOwnershipConflict(t *testing.T) { } func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -611,10 +582,6 @@ func TestWeChatOAuthCallbackBindRejectsLegacyProviderKeyOwnershipConflict(t *tes } func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -737,10 +704,6 @@ func TestCompleteWeChatOAuthRegistrationAfterInvitationPendingSession(t *testing } func TestWeChatOAuthCallbackRepairsLegacyOpenIDOnlyIdentity(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -900,10 +863,6 @@ func TestCompleteWeChatOAuthRegistrationRejectsAdoptExistingUserSession(t *testi } func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_FRONTEND_REDIRECT_URL", "/auth/wechat/callback") - originalAccessTokenURL := wechatOAuthAccessTokenURL originalUserInfoURL := wechatOAuthUserInfoURL t.Cleanup(func() { @@ -1010,6 +969,22 @@ func TestWeChatOAuthCallbackRepairsLegacyProviderKeyCanonicalIdentity(t *testing } func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandler, *dbent.Client) { + return newWeChatOAuthTestHandlerWithSettings(t, invitationEnabled, nil) +} + +func wechatOAuthTestSettings(mode, appID, secret, frontendRedirect string) map[string]string { + return map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: appID, + service.SettingKeyWeChatConnectAppSecret: secret, + service.SettingKeyWeChatConnectMode: mode, + service.SettingKeyWeChatConnectScopes: service.DefaultWeChatConnectScopesForMode(mode), + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: frontendRedirect, + } +} + +func newWeChatOAuthTestHandlerWithSettings(t *testing.T, invitationEnabled bool, extraSettings map[string]string) (*AuthHandler, *dbent.Client) { t.Helper() db, err := sql.Open("sqlite", "file:auth_wechat_oauth?mode=memory&cache=shared") @@ -1036,12 +1011,17 @@ func newWeChatOAuthTestHandler(t *testing.T, invitationEnabled bool) (*AuthHandl UserConcurrency: 1, }, } - settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{ - values: map[string]string{ - service.SettingKeyRegistrationEnabled: "true", - service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), - }, - }, cfg) + values := map[string]string{ + service.SettingKeyRegistrationEnabled: "true", + service.SettingKeyInvitationCodeEnabled: boolSettingValue(invitationEnabled), + } + for key, value := range wechatOAuthTestSettings("open", "wx-open-app", "wx-open-secret", "/auth/wechat/callback") { + values[key] = value + } + for key, value := range extraSettings { + values[key] = value + } + settingSvc := service.NewSettingService(&wechatOAuthSettingRepoStub{values: values}, cfg) authSvc := service.NewAuthService( client, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index d67b29a0..4c5edfbf 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -51,6 +51,14 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` + WeChatConnectAppID string `json:"wechat_connect_app_id"` + WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` + WeChatConnectMode string `json:"wechat_connect_mode"` + WeChatConnectScopes string `json:"wechat_connect_scopes"` + WeChatConnectRedirectURL string `json:"wechat_connect_redirect_url"` + WeChatConnectFrontendRedirectURL string `json:"wechat_connect_frontend_redirect_url"` + OIDCConnectEnabled bool `json:"oidc_connect_enabled"` OIDCConnectProviderName string `json:"oidc_connect_provider_name"` OIDCConnectClientID string `json:"oidc_connect_client_id"` diff --git a/backend/internal/handler/setting_handler_public_test.go b/backend/internal/handler/setting_handler_public_test.go index b50c982c..628d9341 100644 --- a/backend/internal/handler/setting_handler_public_test.go +++ b/backend/internal/handler/setting_handler_public_test.go @@ -84,12 +84,17 @@ func TestSettingHandler_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { gin.SetMode(gin.TestMode) - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "") - - h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{}, &config.Config{}), "test-version") + h := NewSettingHandler(service.NewSettingService(&settingHandlerPublicRepoStub{ + values: map[string]string{ + service.SettingKeyWeChatConnectEnabled: "true", + service.SettingKeyWeChatConnectAppID: "wx-mp-app", + service.SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + service.SettingKeyWeChatConnectMode: "mp", + service.SettingKeyWeChatConnectScopes: "snsapi_base", + service.SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + service.SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}), "test-version") recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) @@ -110,6 +115,6 @@ func TestSettingHandler_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t * require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) require.Equal(t, 0, resp.Code) require.True(t, resp.Data.WeChatOAuthEnabled) - require.True(t, resp.Data.WeChatOAuthOpenEnabled) - require.False(t, resp.Data.WeChatOAuthMPEnabled) + require.False(t, resp.Data.WeChatOAuthOpenEnabled) + require.True(t, resp.Data.WeChatOAuthMPEnabled) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 1dddf77e..4d63cabc 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -111,6 +111,15 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // WeChat Connect OAuth 登录设置 + SettingKeyWeChatConnectEnabled = "wechat_connect_enabled" + SettingKeyWeChatConnectAppID = "wechat_connect_app_id" + SettingKeyWeChatConnectAppSecret = "wechat_connect_app_secret" + SettingKeyWeChatConnectMode = "wechat_connect_mode" + SettingKeyWeChatConnectScopes = "wechat_connect_scopes" + SettingKeyWeChatConnectRedirectURL = "wechat_connect_redirect_url" + SettingKeyWeChatConnectFrontendRedirectURL = "wechat_connect_frontend_redirect_url" + // Generic OIDC OAuth 登录设置 SettingKeyOIDCConnectEnabled = "oidc_connect_enabled" SettingKeyOIDCConnectProviderName = "oidc_connect_provider_name" diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index 34462a3a..2d1f3f42 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -93,6 +93,11 @@ type UpdatePaymentConfigRequest struct { CancelRateLimitWindow *int `json:"cancel_rate_limit_window"` CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"` CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"` + + VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"` + VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"` + VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"` + VisibleMethodWxpayEnabled *bool `json:"payment_visible_method_wxpay_enabled"` } // MethodLimits holds per-payment-type limits. @@ -319,6 +324,10 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow), SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit), SettingCancelWindowMode: derefStr(req.CancelRateLimitMode), + SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource), + SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource), + SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled), + SettingPaymentVisibleMethodWxpayEnabled: formatBoolOrEmpty(req.VisibleMethodWxpayEnabled), } if req.EnabledTypes != nil { m[SettingEnabledPaymentTypes] = strings.Join(req.EnabledTypes, ",") diff --git a/backend/internal/service/payment_config_service_test.go b/backend/internal/service/payment_config_service_test.go index 10919058..d58ee234 100644 --- a/backend/internal/service/payment_config_service_test.go +++ b/backend/internal/service/payment_config_service_test.go @@ -366,7 +366,8 @@ func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client { } type paymentConfigSettingRepoStub struct { - values map[string]string + values map[string]string + updates map[string]string } func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) { @@ -383,10 +384,52 @@ func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []str } return out, nil } -func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error { +func (s *paymentConfigSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error { + s.updates = make(map[string]string, len(values)) + for key, value := range values { + s.updates[key] = value + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + } return nil } func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) { return s.values, nil } func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil } + +func TestUpdatePaymentConfig_PersistsVisibleMethodRouting(t *testing.T) { + repo := &paymentConfigSettingRepoStub{values: map[string]string{}} + svc := &PaymentConfigService{settingRepo: repo} + + alipayEnabled := true + wxpayEnabled := false + err := svc.UpdatePaymentConfig(context.Background(), UpdatePaymentConfigRequest{ + VisibleMethodAlipayEnabled: &alipayEnabled, + VisibleMethodAlipaySource: paymentConfigStrPtr(VisibleMethodSourceEasyPayAlipay), + VisibleMethodWxpayEnabled: &wxpayEnabled, + VisibleMethodWxpaySource: paymentConfigStrPtr(VisibleMethodSourceOfficialWechat), + }) + if err != nil { + t.Fatalf("UpdatePaymentConfig returned error: %v", err) + } + + if repo.values[SettingPaymentVisibleMethodAlipayEnabled] != "true" { + t.Fatalf("alipay enabled = %q, want true", repo.values[SettingPaymentVisibleMethodAlipayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodAlipaySource] != VisibleMethodSourceEasyPayAlipay { + t.Fatalf("alipay source = %q, want %q", repo.values[SettingPaymentVisibleMethodAlipaySource], VisibleMethodSourceEasyPayAlipay) + } + if repo.values[SettingPaymentVisibleMethodWxpayEnabled] != "false" { + t.Fatalf("wxpay enabled = %q, want false", repo.values[SettingPaymentVisibleMethodWxpayEnabled]) + } + if repo.values[SettingPaymentVisibleMethodWxpaySource] != VisibleMethodSourceOfficialWechat { + t.Fatalf("wxpay source = %q, want %q", repo.values[SettingPaymentVisibleMethodWxpaySource], VisibleMethodSourceOfficialWechat) + } +} + +func paymentConfigStrPtr(value string) *string { + return &value +} diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 354f3cd1..01d8c642 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -6,7 +6,6 @@ import ( "log/slog" "math" "net/url" - "os" "strconv" "strings" "time" @@ -512,16 +511,21 @@ func requiresWeChatJSAPICompatibleSelection(req CreateOrderRequest, sel *payment return req.IsWeChatBrowser || strings.TrimSpace(req.OpenID) != "" } -func (s *PaymentService) getWeChatPaymentOAuthCredential(context.Context) (string, string, error) { - appID := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) - appSecret := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) - if appID == "" || appSecret == "" { +func (s *PaymentService) getWeChatPaymentOAuthCredential(ctx context.Context) (string, string, error) { + if s == nil || s.configService == nil || s.configService.settingRepo == nil { + return "", "", infraerrors.ServiceUnavailable( + "WECHAT_PAYMENT_MP_NOT_CONFIGURED", + "wechat in-app payment requires a complete WeChat MP OAuth credential", + ) + } + cfg, err := (&SettingService{settingRepo: s.configService.settingRepo}).GetWeChatConnectOAuthConfig(ctx) + if err != nil || cfg.Mode != "mp" || strings.TrimSpace(cfg.AppID) == "" || strings.TrimSpace(cfg.AppSecret) == "" { return "", "", infraerrors.ServiceUnavailable( "WECHAT_PAYMENT_MP_NOT_CONFIGURED", "wechat in-app payment requires a complete WeChat MP OAuth credential", ) } - return appID, appSecret, nil + return strings.TrimSpace(cfg.AppID), strings.TrimSpace(cfg.AppSecret), nil } func classifyCreatePaymentError(req CreateOrderRequest, providerKey string, err error) error { diff --git a/backend/internal/service/payment_order_jsapi_test.go b/backend/internal/service/payment_order_jsapi_test.go index 08492432..25f209af 100644 --- a/backend/internal/service/payment_order_jsapi_test.go +++ b/backend/internal/service/payment_order_jsapi_test.go @@ -60,10 +60,17 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing } configService := &PaymentConfigService{ - entClient: client, + entClient: client, settingRepo: &paymentConfigSettingRepoStub{values: map[string]string{ - SettingPaymentVisibleMethodWxpayEnabled: "true", - SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + SettingPaymentVisibleMethodWxpayEnabled: "true", + SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat, + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", }}, encryptionKey: []byte(jsapiTestEncryptionKey), } @@ -77,9 +84,6 @@ func TestSelectCreateOrderInstancePrefersJSAPICompatibleWxpayInstance(t *testing configService: configService, } - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx-mp-app") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - sel, err := svc.selectCreateOrderInstance(ctx, CreateOrderRequest{ PaymentType: payment.TypeWxpay, OpenID: "openid-123", diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index 0daa8213..16757323 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -91,10 +91,15 @@ func TestBuildCreateOrderResponseCopiesJSAPIPayload(t *testing.T) { } func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -132,7 +137,7 @@ func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testing.T) { t.Parallel() - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(nil) resp, err := svc.maybeBuildWeChatOAuthRequiredResponse(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -155,10 +160,15 @@ func TestMaybeBuildWeChatOAuthRequiredResponseRequiresMPConfigInWeChat(t *testin } func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t *testing.T) { - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "wx123456") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "wechat-secret") - - svc := &PaymentService{} + svc := newWeChatPaymentOAuthTestService(map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx123456", + SettingKeyWeChatConnectAppSecret: "wechat-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }) resp, err := svc.maybeBuildWeChatOAuthRequiredResponseForSelection(context.Background(), CreateOrderRequest{ Amount: 12.5, @@ -175,3 +185,11 @@ func TestMaybeBuildWeChatOAuthRequiredResponseForSelectionSkipsEasyPayProvider(t t.Fatalf("expected nil response, got %+v", resp) } } + +func newWeChatPaymentOAuthTestService(values map[string]string) *PaymentService { + return &PaymentService{ + configService: &PaymentConfigService{ + settingRepo: &paymentConfigSettingRepoStub{values: values}, + }, + } +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 8c879d52..373c1aef 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,7 +9,6 @@ import ( "fmt" "log/slog" "net/url" - "os" "sort" "strconv" "strings" @@ -173,8 +172,43 @@ var ( const ( defaultAuthSourceBalance = 0 defaultAuthSourceConcurrency = 5 + defaultWeChatConnectMode = "open" + defaultWeChatConnectScopes = "snsapi_login" + defaultWeChatConnectFrontend = "/auth/wechat/callback" ) +func normalizeWeChatConnectModeSetting(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + default: + return "open" + } +} + +func defaultWeChatConnectScopeForMode(mode string) string { + if normalizeWeChatConnectModeSetting(mode) == "mp" { + return "snsapi_userinfo" + } + return defaultWeChatConnectScopes +} + +func normalizeWeChatConnectScopeSetting(raw, mode string) string { + switch normalizeWeChatConnectModeSetting(mode) { + case "mp": + switch strings.TrimSpace(raw) { + case "snsapi_base": + return "snsapi_base" + case "snsapi_userinfo": + return "snsapi_userinfo" + default: + return defaultWeChatConnectScopeForMode(mode) + } + default: + return defaultWeChatConnectScopes + } +} + // NewSettingService 创建系统设置服务实例 func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ @@ -240,6 +274,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, SettingKeyBackendModeEnabled, SettingPaymentEnabled, SettingKeyOIDCConnectEnabled, @@ -274,9 +315,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings if oidcProviderName == "" { oidcProviderName = "OIDC" } - weChatOpenEnabled := isWeChatOAuthOpenConfigured() - weChatMPEnabled := isWeChatOAuthMPConfigured() - weChatEnabled := weChatOpenEnabled || weChatMPEnabled + weChatEnabled, weChatOpenEnabled, weChatMPEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) // Password reset requires email verification to be enabled emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" @@ -431,6 +470,56 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any }, nil } +func DefaultWeChatConnectScopesForMode(mode string) string { + return defaultWeChatConnectScopeForMode(mode) +} + +func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { + cfg := WeChatConnectOAuthConfig{ + Enabled: settings[SettingKeyWeChatConnectEnabled] == "true", + AppID: strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]), + AppSecret: strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]), + Mode: normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]), + Scopes: normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode]), + RedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]), + FrontendRedirectURL: strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]), + } + if cfg.FrontendRedirectURL == "" { + cfg.FrontendRedirectURL = defaultWeChatConnectFrontend + } + + if !cfg.Enabled { + return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if cfg.AppID == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app id not configured") + } + if cfg.AppSecret == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth app secret not configured") + } + if cfg.RedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url not configured") + } + if cfg.FrontendRedirectURL == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url not configured") + } + if err := config.ValidateAbsoluteHTTPURL(cfg.RedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") + } + return cfg, nil +} + +func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool) { + cfg, err := s.parseWeChatConnectOAuthConfig(settings) + if err != nil { + return false, false, false + } + return true, cfg.Mode == "open", cfg.Mode == "mp" +} + // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -467,20 +556,6 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage { return result } -func isWeChatOAuthConfigured() bool { - return isWeChatOAuthOpenConfigured() || isWeChatOAuthMPConfigured() -} - -func isWeChatOAuthOpenConfigured() bool { - return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" && - strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != "" -} - -func isWeChatOAuthMPConfigured() bool { - return strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" && - strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != "" -} - // safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]". func safeRawJSONArray(raw string) json.RawMessage { raw = strings.TrimSpace(raw) @@ -625,6 +700,15 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting } settings.PaymentVisibleMethodAlipaySource = alipaySource settings.PaymentVisibleMethodWxpaySource = wxpaySource + settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) + settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) + settings.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings.WeChatConnectMode) + settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) + settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) + settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) + if settings.WeChatConnectFrontendRedirectURL == "" { + settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } updates := make(map[string]string) @@ -694,6 +778,17 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret } + // WeChat Connect OAuth 登录 + updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) + updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID + updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode + updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes + updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL + updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL + if settings.WeChatConnectAppSecret != "" { + updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -1200,6 +1295,10 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyTablePageSizeOptions: "[10,20,50,100]", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", + SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectMode: "open", + SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, SettingKeyOIDCConnectEnabled: "false", SettingKeyOIDCConnectProviderName: "OIDC", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), @@ -1491,6 +1590,19 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" + // WeChat Connect 设置:完全以 DB 系统设置为准。 + result.WeChatConnectEnabled = settings[SettingKeyWeChatConnectEnabled] == "true" + result.WeChatConnectAppID = strings.TrimSpace(settings[SettingKeyWeChatConnectAppID]) + result.WeChatConnectAppSecret = strings.TrimSpace(settings[SettingKeyWeChatConnectAppSecret]) + result.WeChatConnectAppSecretConfigured = result.WeChatConnectAppSecret != "" + result.WeChatConnectMode = normalizeWeChatConnectModeSetting(settings[SettingKeyWeChatConnectMode]) + result.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings[SettingKeyWeChatConnectScopes], settings[SettingKeyWeChatConnectMode]) + result.WeChatConnectRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectRedirectURL]) + result.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings[SettingKeyWeChatConnectFrontendRedirectURL]) + if result.WeChatConnectFrontendRedirectURL == "" { + result.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -1972,6 +2084,26 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。 +// +// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。 +func (s *SettingService) GetWeChatConnectOAuthConfig(ctx context.Context) (WeChatConnectOAuthConfig, error) { + keys := []string{ + SettingKeyWeChatConnectEnabled, + SettingKeyWeChatConnectAppID, + SettingKeyWeChatConnectAppSecret, + SettingKeyWeChatConnectMode, + SettingKeyWeChatConnectScopes, + SettingKeyWeChatConnectRedirectURL, + SettingKeyWeChatConnectFrontendRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return WeChatConnectOAuthConfig{}, fmt.Errorf("get wechat connect settings: %w", err) + } + return s.parseWeChatConnectOAuthConfig(settings) +} + // GetOverloadCooldownSettings 获取529过载冷却配置 func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 4cfa9f0c..ffc069dc 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -92,16 +92,21 @@ func TestSettingService_GetPublicSettings_ExposesForceEmailOnThirdPartySignup(t } func TestSettingService_GetPublicSettings_ExposesWeChatOAuthModeCapabilities(t *testing.T) { - t.Setenv("WECHAT_OAUTH_OPEN_APP_ID", "wx-open-app") - t.Setenv("WECHAT_OAUTH_OPEN_APP_SECRET", "wx-open-secret") - t.Setenv("WECHAT_OAUTH_MP_APP_ID", "") - t.Setenv("WECHAT_OAUTH_MP_APP_SECRET", "") - - svc := NewSettingService(&settingPublicRepoStub{}, &config.Config{}) + svc := NewSettingService(&settingPublicRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-mp-app", + SettingKeyWeChatConnectAppSecret: "wx-mp-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + }, &config.Config{}) settings, err := svc.GetPublicSettings(context.Background()) require.NoError(t, err) require.True(t, settings.WeChatOAuthEnabled) - require.True(t, settings.WeChatOAuthOpenEnabled) - require.False(t, settings.WeChatOAuthMPEnabled) + require.False(t, settings.WeChatOAuthOpenEnabled) + require.True(t, settings.WeChatOAuthMPEnabled) } diff --git a/backend/internal/service/setting_service_wechat_config_test.go b/backend/internal/service/setting_service_wechat_config_test.go new file mode 100644 index 00000000..2cb312cc --- /dev/null +++ b/backend/internal/service/setting_service_wechat_config_test.go @@ -0,0 +1,77 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type settingWeChatRepoStub struct { + values map[string]string +} + +func (s *settingWeChatRepoStub) Get(context.Context, string) (*Setting, error) { + panic("unexpected Get call") +} + +func (s *settingWeChatRepoStub) GetValue(_ context.Context, key string) (string, error) { + if value, ok := s.values[key]; ok { + return value, nil + } + return "", ErrSettingNotFound +} + +func (s *settingWeChatRepoStub) Set(context.Context, string, string) error { + panic("unexpected Set call") +} + +func (s *settingWeChatRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingWeChatRepoStub) SetMultiple(context.Context, map[string]string) error { + panic("unexpected SetMultiple call") +} + +func (s *settingWeChatRepoStub) GetAll(context.Context) (map[string]string, error) { + panic("unexpected GetAll call") +} + +func (s *settingWeChatRepoStub) Delete(context.Context, string) error { + panic("unexpected Delete call") +} + +func TestSettingService_GetWeChatConnectOAuthConfig_UsesDatabaseOverrides(t *testing.T) { + repo := &settingWeChatRepoStub{ + values: map[string]string{ + SettingKeyWeChatConnectEnabled: "true", + SettingKeyWeChatConnectAppID: "wx-db-app", + SettingKeyWeChatConnectAppSecret: "wx-db-secret", + SettingKeyWeChatConnectMode: "mp", + SettingKeyWeChatConnectScopes: "snsapi_base", + SettingKeyWeChatConnectRedirectURL: "https://api.example.com/api/v1/auth/oauth/wechat/callback", + SettingKeyWeChatConnectFrontendRedirectURL: "/auth/wechat/callback", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + got, err := svc.GetWeChatConnectOAuthConfig(context.Background()) + require.NoError(t, err) + require.True(t, got.Enabled) + require.Equal(t, "wx-db-app", got.AppID) + require.Equal(t, "wx-db-secret", got.AppSecret) + require.Equal(t, "mp", got.Mode) + require.Equal(t, "snsapi_base", got.Scopes) + require.Equal(t, "https://api.example.com/api/v1/auth/oauth/wechat/callback", got.RedirectURL) + require.Equal(t, "/auth/wechat/callback", got.FrontendRedirectURL) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 1b859dbd..41229bfb 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -31,6 +31,16 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // WeChat Connect OAuth 登录 + WeChatConnectEnabled bool + WeChatConnectAppID string + WeChatConnectAppSecret string + WeChatConnectAppSecretConfigured bool + WeChatConnectMode string + WeChatConnectScopes string + WeChatConnectRedirectURL string + WeChatConnectFrontendRedirectURL string + // Generic OIDC OAuth 登录 OIDCConnectEnabled bool OIDCConnectProviderName string @@ -177,6 +187,16 @@ type PublicSettings struct { BalanceLowNotifyRechargeURL string } +type WeChatConnectOAuthConfig struct { + Enabled bool + AppID string + AppSecret string + Mode string + Scopes string + RedirectURL string + FrontendRedirectURL string +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/frontend/src/api/__tests__/settings.wechatConnect.spec.ts b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts new file mode 100644 index 00000000..eccb7214 --- /dev/null +++ b/frontend/src/api/__tests__/settings.wechatConnect.spec.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from "vitest"; + +import { + defaultWeChatConnectScopesForMode, + normalizeWeChatConnectMode, +} from "@/api/admin/settings"; + +describe("admin settings wechat connect helpers", () => { + it("normalizes legacy or noisy mode values to the backend contract", () => { + expect(normalizeWeChatConnectMode("OPEN")).toBe("open"); + expect(normalizeWeChatConnectMode(" open_platform ")).toBe("open"); + expect(normalizeWeChatConnectMode("mp")).toBe("mp"); + expect(normalizeWeChatConnectMode("official_account")).toBe("mp"); + expect(normalizeWeChatConnectMode("unknown")).toBe("open"); + }); + + it("maps each mode to the backend default scopes", () => { + expect(defaultWeChatConnectScopesForMode("open")).toBe("snsapi_login"); + expect(defaultWeChatConnectScopesForMode("mp")).toBe("snsapi_userinfo"); + }); +}); diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 235bda7b..ed78a1bc 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -3,155 +3,240 @@ * Handles system settings management for administrators */ -import { apiClient } from '../client' -import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from '@/types' +import { apiClient } from "../client"; +import type { CustomMenuItem, CustomEndpoint, NotifyEmailEntry } from "@/types"; export interface DefaultSubscriptionSetting { - group_id: number - validity_days: number + group_id: number; + validity_days: number; } -export type AuthSourceType = 'email' | 'linuxdo' | 'oidc' | 'wechat' +export type AuthSourceType = "email" | "linuxdo" | "oidc" | "wechat"; export interface AuthSourceDefaultsValue { - balance: number - concurrency: number - subscriptions: DefaultSubscriptionSetting[] - grant_on_signup: boolean - grant_on_first_bind: boolean -} - -export type AuthSourceDefaultsState = Record -export type PaymentVisibleMethod = 'alipay' | 'wxpay' + balance: number; + concurrency: number; + subscriptions: DefaultSubscriptionSetting[]; + grant_on_signup: boolean; + grant_on_first_bind: boolean; +} + +export type AuthSourceDefaultsState = Record< + AuthSourceType, + AuthSourceDefaultsValue +>; +export type PaymentVisibleMethod = "alipay" | "wxpay"; export type PaymentVisibleMethodSource = - | '' - | 'official_alipay' - | 'easypay_alipay' - | 'official_wxpay' - | 'easypay_wxpay' + | "" + | "official_alipay" + | "easypay_alipay" + | "official_wxpay" + | "easypay_wxpay"; +export type WeChatConnectMode = "open" | "mp"; export interface PaymentVisibleMethodSourceOption { - value: PaymentVisibleMethodSource - labelZh: string - labelEn: string + value: PaymentVisibleMethodSource; + labelZh: string; + labelEn: string; +} + +export interface WeChatConnectModeOption { + value: WeChatConnectMode; + labelZh: string; + labelEn: string; } -const AUTH_SOURCE_TYPES: AuthSourceType[] = ['email', 'linuxdo', 'oidc', 'wechat'] -const AUTH_SOURCE_DEFAULT_BALANCE = 0 -const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5 +const AUTH_SOURCE_TYPES: AuthSourceType[] = [ + "email", + "linuxdo", + "oidc", + "wechat", +]; +const AUTH_SOURCE_DEFAULT_BALANCE = 0; +const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; const PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS: Record< PaymentVisibleMethod, PaymentVisibleMethodSourceOption[] > = { alipay: [ - { value: '', labelZh: '未配置', labelEn: 'Not configured' }, - { value: 'official_alipay', labelZh: '支付宝官方', labelEn: 'Official Alipay' }, - { value: 'easypay_alipay', labelZh: '易支付支付宝', labelEn: 'EasyPay Alipay' }, + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_alipay", + labelZh: "支付宝官方", + labelEn: "Official Alipay", + }, + { + value: "easypay_alipay", + labelZh: "易支付支付宝", + labelEn: "EasyPay Alipay", + }, ], wxpay: [ - { value: '', labelZh: '未配置', labelEn: 'Not configured' }, - { value: 'official_wxpay', labelZh: '微信官方', labelEn: 'Official WeChat Pay' }, - { value: 'easypay_wxpay', labelZh: '易支付微信', labelEn: 'EasyPay WeChat Pay' }, + { value: "", labelZh: "未配置", labelEn: "Not configured" }, + { + value: "official_wxpay", + labelZh: "微信官方", + labelEn: "Official WeChat Pay", + }, + { + value: "easypay_wxpay", + labelZh: "易支付微信", + labelEn: "EasyPay WeChat Pay", + }, ], -} +}; const PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES: Record< PaymentVisibleMethod, Record > = { alipay: { - official_alipay: 'official_alipay', - alipay: 'official_alipay', - alipay_direct: 'official_alipay', - official: 'official_alipay', - easypay_alipay: 'easypay_alipay', - easypay: 'easypay_alipay', + official_alipay: "official_alipay", + alipay: "official_alipay", + alipay_direct: "official_alipay", + official: "official_alipay", + easypay_alipay: "easypay_alipay", + easypay: "easypay_alipay", }, wxpay: { - official_wxpay: 'official_wxpay', - wxpay: 'official_wxpay', - wxpay_direct: 'official_wxpay', - wechat: 'official_wxpay', - official: 'official_wxpay', - easypay_wxpay: 'easypay_wxpay', - easypay: 'easypay_wxpay', + official_wxpay: "official_wxpay", + wxpay: "official_wxpay", + wxpay_direct: "official_wxpay", + wechat: "official_wxpay", + official: "official_wxpay", + easypay_wxpay: "easypay_wxpay", + easypay: "easypay_wxpay", }, -} +}; +const WECHAT_CONNECT_MODE_OPTIONS: WeChatConnectModeOption[] = [ + { value: "open", labelZh: "微信开放平台", labelEn: "WeChat Open Platform" }, + { + value: "mp", + labelZh: "微信公众号 / 小程序", + labelEn: "WeChat Official Account / Mini Program", + }, +]; +const WECHAT_CONNECT_MODE_ALIASES: Record = { + open: "open", + open_platform: "open", + official: "open", + wx_open: "open", + mp: "mp", + official_account: "mp", + wechat_mp: "mp", + mini_program: "mp", +}; export function normalizeDefaultSubscriptionSettings( - subscriptions: DefaultSubscriptionSetting[] | null | undefined + subscriptions: DefaultSubscriptionSetting[] | null | undefined, ): DefaultSubscriptionSetting[] { - if (!Array.isArray(subscriptions)) return [] + if (!Array.isArray(subscriptions)) return []; return subscriptions .filter((item) => item.group_id > 0 && item.validity_days > 0) .map((item) => ({ group_id: Math.floor(item.group_id), - validity_days: Math.min(36500, Math.max(1, Math.floor(item.validity_days))) - })) + validity_days: Math.min( + 36500, + Math.max(1, Math.floor(item.validity_days)), + ), + })); } export function buildAuthSourceDefaultsState( - settings: Partial + settings: Partial, ): AuthSourceDefaultsState { - const raw = settings as Record + const raw = settings as Record; return AUTH_SOURCE_TYPES.reduce((acc, source) => { - const subscriptions = raw[`auth_source_default_${source}_subscriptions`] + const subscriptions = raw[`auth_source_default_${source}_subscriptions`]; acc[source] = { - balance: Number(raw[`auth_source_default_${source}_balance`] ?? AUTH_SOURCE_DEFAULT_BALANCE), + balance: Number( + raw[`auth_source_default_${source}_balance`] ?? + AUTH_SOURCE_DEFAULT_BALANCE, + ), concurrency: Math.max( 1, - Number(raw[`auth_source_default_${source}_concurrency`] ?? AUTH_SOURCE_DEFAULT_CONCURRENCY) + Number( + raw[`auth_source_default_${source}_concurrency`] ?? + AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), ), subscriptions: normalizeDefaultSubscriptionSettings( - Array.isArray(subscriptions) ? (subscriptions as DefaultSubscriptionSetting[]) : [] + Array.isArray(subscriptions) + ? (subscriptions as DefaultSubscriptionSetting[]) + : [], ), - grant_on_signup: raw[`auth_source_default_${source}_grant_on_signup`] !== false, - grant_on_first_bind: raw[`auth_source_default_${source}_grant_on_first_bind`] === true, - } - return acc - }, {} as AuthSourceDefaultsState) + grant_on_signup: + raw[`auth_source_default_${source}_grant_on_signup`] !== false, + grant_on_first_bind: + raw[`auth_source_default_${source}_grant_on_first_bind`] === true, + }; + return acc; + }, {} as AuthSourceDefaultsState); } export function appendAuthSourceDefaultsToUpdateRequest( payload: UpdateSettingsRequest, - authSourceDefaults: AuthSourceDefaultsState + authSourceDefaults: AuthSourceDefaultsState, ): UpdateSettingsRequest { - const target = payload as Record + const target = payload as Record; for (const source of AUTH_SOURCE_TYPES) { - const current = authSourceDefaults[source] - target[`auth_source_default_${source}_balance`] = Number(current.balance) || 0 + const current = authSourceDefaults[source]; + target[`auth_source_default_${source}_balance`] = + Number(current.balance) || 0; target[`auth_source_default_${source}_concurrency`] = Math.max( 1, - Math.floor(Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY) - ) - target[`auth_source_default_${source}_subscriptions`] = normalizeDefaultSubscriptionSettings( - current.subscriptions - ) - target[`auth_source_default_${source}_grant_on_signup`] = current.grant_on_signup - target[`auth_source_default_${source}_grant_on_first_bind`] = current.grant_on_first_bind + Math.floor( + Number(current.concurrency) || AUTH_SOURCE_DEFAULT_CONCURRENCY, + ), + ); + target[`auth_source_default_${source}_subscriptions`] = + normalizeDefaultSubscriptionSettings(current.subscriptions); + target[`auth_source_default_${source}_grant_on_signup`] = + current.grant_on_signup; + target[`auth_source_default_${source}_grant_on_first_bind`] = + current.grant_on_first_bind; } - return payload + return payload; } export function getPaymentVisibleMethodSourceOptions( - method: PaymentVisibleMethod + method: PaymentVisibleMethod, ): PaymentVisibleMethodSourceOption[] { - return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method] + return PAYMENT_VISIBLE_METHOD_SOURCE_OPTIONS[method]; } export function normalizePaymentVisibleMethodSource( method: PaymentVisibleMethod, - source: unknown + source: unknown, ): PaymentVisibleMethodSource { - if (typeof source !== 'string') return '' + if (typeof source !== "string") return ""; - const normalized = source.trim().toLowerCase() - if (!normalized) return '' + const normalized = source.trim().toLowerCase(); + if (!normalized) return ""; - return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? '' + return PAYMENT_VISIBLE_METHOD_SOURCE_ALIASES[method][normalized] ?? ""; +} + +export function getWeChatConnectModeOptions(): WeChatConnectModeOption[] { + return WECHAT_CONNECT_MODE_OPTIONS; +} + +export function normalizeWeChatConnectMode(source: unknown): WeChatConnectMode { + if (typeof source !== "string") return "open"; + + const normalized = source.trim().toLowerCase(); + if (!normalized) return "open"; + + return WECHAT_CONNECT_MODE_ALIASES[normalized] ?? "open"; +} + +export function defaultWeChatConnectScopesForMode(mode: unknown): string { + return normalizeWeChatConnectMode(mode) === "mp" + ? "snsapi_userinfo" + : "snsapi_login"; } /** @@ -159,293 +244,309 @@ export function normalizePaymentVisibleMethodSource( */ export interface SystemSettings { // Registration settings - registration_enabled: boolean - email_verify_enabled: boolean - registration_email_suffix_whitelist: string[] - promo_code_enabled: boolean - password_reset_enabled: boolean - frontend_url: string - invitation_code_enabled: boolean - totp_enabled: boolean // TOTP 双因素认证 - totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 + registration_enabled: boolean; + email_verify_enabled: boolean; + registration_email_suffix_whitelist: string[]; + promo_code_enabled: boolean; + password_reset_enabled: boolean; + frontend_url: string; + invitation_code_enabled: boolean; + totp_enabled: boolean; // TOTP 双因素认证 + totp_encryption_key_configured: boolean; // TOTP 加密密钥是否已配置 // Default settings - default_balance: number - default_concurrency: number - default_subscriptions: DefaultSubscriptionSetting[] - auth_source_default_email_balance?: number - auth_source_default_email_concurrency?: number - auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_grant_on_signup?: boolean - auth_source_default_email_grant_on_first_bind?: boolean - auth_source_default_linuxdo_balance?: number - auth_source_default_linuxdo_concurrency?: number - auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_linuxdo_grant_on_signup?: boolean - auth_source_default_linuxdo_grant_on_first_bind?: boolean - auth_source_default_oidc_balance?: number - auth_source_default_oidc_concurrency?: number - auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_oidc_grant_on_signup?: boolean - auth_source_default_oidc_grant_on_first_bind?: boolean - auth_source_default_wechat_balance?: number - auth_source_default_wechat_concurrency?: number - auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_wechat_grant_on_signup?: boolean - auth_source_default_wechat_grant_on_first_bind?: boolean - force_email_on_third_party_signup?: boolean + default_balance: number; + default_concurrency: number; + default_subscriptions: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; // OEM settings - site_name: string - site_logo: string - site_subtitle: string - api_base_url: string - contact_info: string - doc_url: string - home_content: string - hide_ccs_import_button: boolean - table_default_page_size: number - table_page_size_options: number[] - backend_mode_enabled: boolean - custom_menu_items: CustomMenuItem[] - custom_endpoints: CustomEndpoint[] + site_name: string; + site_logo: string; + site_subtitle: string; + api_base_url: string; + contact_info: string; + doc_url: string; + home_content: string; + hide_ccs_import_button: boolean; + table_default_page_size: number; + table_page_size_options: number[]; + backend_mode_enabled: boolean; + custom_menu_items: CustomMenuItem[]; + custom_endpoints: CustomEndpoint[]; // SMTP settings - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password_configured: boolean - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password_configured: boolean; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; // Cloudflare Turnstile settings - turnstile_enabled: boolean - turnstile_site_key: string - turnstile_secret_key_configured: boolean + turnstile_enabled: boolean; + turnstile_site_key: string; + turnstile_secret_key_configured: boolean; // LinuxDo Connect OAuth settings - linuxdo_connect_enabled: boolean - linuxdo_connect_client_id: string - linuxdo_connect_client_secret_configured: boolean - linuxdo_connect_redirect_url: string + linuxdo_connect_enabled: boolean; + linuxdo_connect_client_id: string; + linuxdo_connect_client_secret_configured: boolean; + linuxdo_connect_redirect_url: string; + + // WeChat Connect OAuth settings + wechat_connect_enabled: boolean; + wechat_connect_app_id: string; + wechat_connect_app_secret_configured: boolean; + wechat_connect_mode: string; + wechat_connect_scopes: string; + wechat_connect_redirect_url: string; + wechat_connect_frontend_redirect_url: string; // Generic OIDC OAuth settings - oidc_connect_enabled: boolean - oidc_connect_provider_name: string - oidc_connect_client_id: string - oidc_connect_client_secret_configured: boolean - oidc_connect_issuer_url: string - oidc_connect_discovery_url: string - oidc_connect_authorize_url: string - oidc_connect_token_url: string - oidc_connect_userinfo_url: string - oidc_connect_jwks_url: string - oidc_connect_scopes: string - oidc_connect_redirect_url: string - oidc_connect_frontend_redirect_url: string - oidc_connect_token_auth_method: string - oidc_connect_use_pkce: boolean - oidc_connect_validate_id_token: boolean - oidc_connect_allowed_signing_algs: string - oidc_connect_clock_skew_seconds: number - oidc_connect_require_email_verified: boolean - oidc_connect_userinfo_email_path: string - oidc_connect_userinfo_id_path: string - oidc_connect_userinfo_username_path: string + oidc_connect_enabled: boolean; + oidc_connect_provider_name: string; + oidc_connect_client_id: string; + oidc_connect_client_secret_configured: boolean; + oidc_connect_issuer_url: string; + oidc_connect_discovery_url: string; + oidc_connect_authorize_url: string; + oidc_connect_token_url: string; + oidc_connect_userinfo_url: string; + oidc_connect_jwks_url: string; + oidc_connect_scopes: string; + oidc_connect_redirect_url: string; + oidc_connect_frontend_redirect_url: string; + oidc_connect_token_auth_method: string; + oidc_connect_use_pkce: boolean; + oidc_connect_validate_id_token: boolean; + oidc_connect_allowed_signing_algs: string; + oidc_connect_clock_skew_seconds: number; + oidc_connect_require_email_verified: boolean; + oidc_connect_userinfo_email_path: string; + oidc_connect_userinfo_id_path: string; + oidc_connect_userinfo_username_path: string; // Model fallback configuration - enable_model_fallback: boolean - fallback_model_anthropic: string - fallback_model_openai: string - fallback_model_gemini: string - fallback_model_antigravity: string + enable_model_fallback: boolean; + fallback_model_anthropic: string; + fallback_model_openai: string; + fallback_model_gemini: string; + fallback_model_antigravity: string; // Identity patch configuration (Claude -> Gemini) - enable_identity_patch: boolean - identity_patch_prompt: string + enable_identity_patch: boolean; + identity_patch_prompt: string; // Ops Monitoring (vNext) - ops_monitoring_enabled: boolean - ops_realtime_monitoring_enabled: boolean - ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds: number + ops_monitoring_enabled: boolean; + ops_realtime_monitoring_enabled: boolean; + ops_query_mode_default: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds: number; // Claude Code version check - min_claude_code_version: string - max_claude_code_version: string + min_claude_code_version: string; + max_claude_code_version: string; // 分组隔离 - allow_ungrouped_key_scheduling: boolean + allow_ungrouped_key_scheduling: boolean; // Gateway forwarding behavior - enable_fingerprint_unification: boolean - enable_metadata_passthrough: boolean - enable_cch_signing: boolean - web_search_emulation_enabled?: boolean + enable_fingerprint_unification: boolean; + enable_metadata_passthrough: boolean; + enable_cch_signing: boolean; + web_search_emulation_enabled?: boolean; // Payment configuration - payment_enabled: boolean - payment_min_amount: number - payment_max_amount: number - payment_daily_limit: number - payment_order_timeout_minutes: number - payment_max_pending_orders: number - payment_enabled_types: string[] - payment_balance_disabled: boolean - payment_balance_recharge_multiplier: number - payment_recharge_fee_rate: number - payment_load_balance_strategy: string - payment_product_name_prefix: string - payment_product_name_suffix: string - payment_help_image_url: string - payment_help_text: string - payment_cancel_rate_limit_enabled: boolean - payment_cancel_rate_limit_max: number - payment_cancel_rate_limit_window: number - payment_cancel_rate_limit_unit: string - payment_cancel_rate_limit_window_mode: string - payment_visible_method_alipay_source?: string - payment_visible_method_wxpay_source?: string - payment_visible_method_alipay_enabled?: boolean - payment_visible_method_wxpay_enabled?: boolean - openai_advanced_scheduler_enabled?: boolean + payment_enabled: boolean; + payment_min_amount: number; + payment_max_amount: number; + payment_daily_limit: number; + payment_order_timeout_minutes: number; + payment_max_pending_orders: number; + payment_enabled_types: string[]; + payment_balance_disabled: boolean; + payment_balance_recharge_multiplier: number; + payment_recharge_fee_rate: number; + payment_load_balance_strategy: string; + payment_product_name_prefix: string; + payment_product_name_suffix: string; + payment_help_image_url: string; + payment_help_text: string; + payment_cancel_rate_limit_enabled: boolean; + payment_cancel_rate_limit_max: number; + payment_cancel_rate_limit_window: number; + payment_cancel_rate_limit_unit: string; + payment_cancel_rate_limit_window_mode: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled: boolean - balance_low_notify_threshold: number - balance_low_notify_recharge_url: string - account_quota_notify_enabled: boolean - account_quota_notify_emails: NotifyEmailEntry[] + balance_low_notify_enabled: boolean; + balance_low_notify_threshold: number; + balance_low_notify_recharge_url: string; + account_quota_notify_enabled: boolean; + account_quota_notify_emails: NotifyEmailEntry[]; } export interface UpdateSettingsRequest { - registration_enabled?: boolean - email_verify_enabled?: boolean - registration_email_suffix_whitelist?: string[] - promo_code_enabled?: boolean - password_reset_enabled?: boolean - frontend_url?: string - invitation_code_enabled?: boolean - totp_enabled?: boolean // TOTP 双因素认证 - default_balance?: number - default_concurrency?: number - default_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_balance?: number - auth_source_default_email_concurrency?: number - auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_email_grant_on_signup?: boolean - auth_source_default_email_grant_on_first_bind?: boolean - auth_source_default_linuxdo_balance?: number - auth_source_default_linuxdo_concurrency?: number - auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_linuxdo_grant_on_signup?: boolean - auth_source_default_linuxdo_grant_on_first_bind?: boolean - auth_source_default_oidc_balance?: number - auth_source_default_oidc_concurrency?: number - auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_oidc_grant_on_signup?: boolean - auth_source_default_oidc_grant_on_first_bind?: boolean - auth_source_default_wechat_balance?: number - auth_source_default_wechat_concurrency?: number - auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[] - auth_source_default_wechat_grant_on_signup?: boolean - auth_source_default_wechat_grant_on_first_bind?: boolean - force_email_on_third_party_signup?: boolean - site_name?: string - site_logo?: string - site_subtitle?: string - api_base_url?: string - contact_info?: string - doc_url?: string - home_content?: string - hide_ccs_import_button?: boolean - table_default_page_size?: number - table_page_size_options?: number[] - backend_mode_enabled?: boolean - custom_menu_items?: CustomMenuItem[] - custom_endpoints?: CustomEndpoint[] - smtp_host?: string - smtp_port?: number - smtp_username?: string - smtp_password?: string - smtp_from_email?: string - smtp_from_name?: string - smtp_use_tls?: boolean - turnstile_enabled?: boolean - turnstile_site_key?: string - turnstile_secret_key?: string - linuxdo_connect_enabled?: boolean - linuxdo_connect_client_id?: string - linuxdo_connect_client_secret?: string - linuxdo_connect_redirect_url?: string - oidc_connect_enabled?: boolean - oidc_connect_provider_name?: string - oidc_connect_client_id?: string - oidc_connect_client_secret?: string - oidc_connect_issuer_url?: string - oidc_connect_discovery_url?: string - oidc_connect_authorize_url?: string - oidc_connect_token_url?: string - oidc_connect_userinfo_url?: string - oidc_connect_jwks_url?: string - oidc_connect_scopes?: string - oidc_connect_redirect_url?: string - oidc_connect_frontend_redirect_url?: string - oidc_connect_token_auth_method?: string - oidc_connect_use_pkce?: boolean - oidc_connect_validate_id_token?: boolean - oidc_connect_allowed_signing_algs?: string - oidc_connect_clock_skew_seconds?: number - oidc_connect_require_email_verified?: boolean - oidc_connect_userinfo_email_path?: string - oidc_connect_userinfo_id_path?: string - oidc_connect_userinfo_username_path?: string - enable_model_fallback?: boolean - fallback_model_anthropic?: string - fallback_model_openai?: string - fallback_model_gemini?: string - fallback_model_antigravity?: string - enable_identity_patch?: boolean - identity_patch_prompt?: string - ops_monitoring_enabled?: boolean - ops_realtime_monitoring_enabled?: boolean - ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string - ops_metrics_interval_seconds?: number - min_claude_code_version?: string - max_claude_code_version?: string - allow_ungrouped_key_scheduling?: boolean - enable_fingerprint_unification?: boolean - enable_metadata_passthrough?: boolean - enable_cch_signing?: boolean + registration_enabled?: boolean; + email_verify_enabled?: boolean; + registration_email_suffix_whitelist?: string[]; + promo_code_enabled?: boolean; + password_reset_enabled?: boolean; + frontend_url?: string; + invitation_code_enabled?: boolean; + totp_enabled?: boolean; // TOTP 双因素认证 + default_balance?: number; + default_concurrency?: number; + default_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_balance?: number; + auth_source_default_email_concurrency?: number; + auth_source_default_email_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_email_grant_on_signup?: boolean; + auth_source_default_email_grant_on_first_bind?: boolean; + auth_source_default_linuxdo_balance?: number; + auth_source_default_linuxdo_concurrency?: number; + auth_source_default_linuxdo_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_linuxdo_grant_on_signup?: boolean; + auth_source_default_linuxdo_grant_on_first_bind?: boolean; + auth_source_default_oidc_balance?: number; + auth_source_default_oidc_concurrency?: number; + auth_source_default_oidc_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_oidc_grant_on_signup?: boolean; + auth_source_default_oidc_grant_on_first_bind?: boolean; + auth_source_default_wechat_balance?: number; + auth_source_default_wechat_concurrency?: number; + auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_wechat_grant_on_signup?: boolean; + auth_source_default_wechat_grant_on_first_bind?: boolean; + force_email_on_third_party_signup?: boolean; + site_name?: string; + site_logo?: string; + site_subtitle?: string; + api_base_url?: string; + contact_info?: string; + doc_url?: string; + home_content?: string; + hide_ccs_import_button?: boolean; + table_default_page_size?: number; + table_page_size_options?: number[]; + backend_mode_enabled?: boolean; + custom_menu_items?: CustomMenuItem[]; + custom_endpoints?: CustomEndpoint[]; + smtp_host?: string; + smtp_port?: number; + smtp_username?: string; + smtp_password?: string; + smtp_from_email?: string; + smtp_from_name?: string; + smtp_use_tls?: boolean; + turnstile_enabled?: boolean; + turnstile_site_key?: string; + turnstile_secret_key?: string; + linuxdo_connect_enabled?: boolean; + linuxdo_connect_client_id?: string; + linuxdo_connect_client_secret?: string; + linuxdo_connect_redirect_url?: string; + wechat_connect_enabled?: boolean; + wechat_connect_app_id?: string; + wechat_connect_app_secret?: string; + wechat_connect_mode?: string; + wechat_connect_scopes?: string; + wechat_connect_redirect_url?: string; + wechat_connect_frontend_redirect_url?: string; + oidc_connect_enabled?: boolean; + oidc_connect_provider_name?: string; + oidc_connect_client_id?: string; + oidc_connect_client_secret?: string; + oidc_connect_issuer_url?: string; + oidc_connect_discovery_url?: string; + oidc_connect_authorize_url?: string; + oidc_connect_token_url?: string; + oidc_connect_userinfo_url?: string; + oidc_connect_jwks_url?: string; + oidc_connect_scopes?: string; + oidc_connect_redirect_url?: string; + oidc_connect_frontend_redirect_url?: string; + oidc_connect_token_auth_method?: string; + oidc_connect_use_pkce?: boolean; + oidc_connect_validate_id_token?: boolean; + oidc_connect_allowed_signing_algs?: string; + oidc_connect_clock_skew_seconds?: number; + oidc_connect_require_email_verified?: boolean; + oidc_connect_userinfo_email_path?: string; + oidc_connect_userinfo_id_path?: string; + oidc_connect_userinfo_username_path?: string; + enable_model_fallback?: boolean; + fallback_model_anthropic?: string; + fallback_model_openai?: string; + fallback_model_gemini?: string; + fallback_model_antigravity?: string; + enable_identity_patch?: boolean; + identity_patch_prompt?: string; + ops_monitoring_enabled?: boolean; + ops_realtime_monitoring_enabled?: boolean; + ops_query_mode_default?: "auto" | "raw" | "preagg" | string; + ops_metrics_interval_seconds?: number; + min_claude_code_version?: string; + max_claude_code_version?: string; + allow_ungrouped_key_scheduling?: boolean; + enable_fingerprint_unification?: boolean; + enable_metadata_passthrough?: boolean; + enable_cch_signing?: boolean; // Payment configuration - payment_enabled?: boolean - payment_min_amount?: number - payment_max_amount?: number - payment_daily_limit?: number - payment_order_timeout_minutes?: number - payment_max_pending_orders?: number - payment_enabled_types?: string[] - payment_balance_disabled?: boolean - payment_balance_recharge_multiplier?: number - payment_recharge_fee_rate?: number - payment_load_balance_strategy?: string - payment_product_name_prefix?: string - payment_product_name_suffix?: string - payment_help_image_url?: string - payment_help_text?: string - payment_cancel_rate_limit_enabled?: boolean - payment_cancel_rate_limit_max?: number - payment_cancel_rate_limit_window?: number - payment_cancel_rate_limit_unit?: string - payment_cancel_rate_limit_window_mode?: string - payment_visible_method_alipay_source?: string - payment_visible_method_wxpay_source?: string - payment_visible_method_alipay_enabled?: boolean - payment_visible_method_wxpay_enabled?: boolean - openai_advanced_scheduler_enabled?: boolean + payment_enabled?: boolean; + payment_min_amount?: number; + payment_max_amount?: number; + payment_daily_limit?: number; + payment_order_timeout_minutes?: number; + payment_max_pending_orders?: number; + payment_enabled_types?: string[]; + payment_balance_disabled?: boolean; + payment_balance_recharge_multiplier?: number; + payment_recharge_fee_rate?: number; + payment_load_balance_strategy?: string; + payment_product_name_prefix?: string; + payment_product_name_suffix?: string; + payment_help_image_url?: string; + payment_help_text?: string; + payment_cancel_rate_limit_enabled?: boolean; + payment_cancel_rate_limit_max?: number; + payment_cancel_rate_limit_window?: number; + payment_cancel_rate_limit_unit?: string; + payment_cancel_rate_limit_window_mode?: string; + payment_visible_method_alipay_source?: string; + payment_visible_method_wxpay_source?: string; + payment_visible_method_alipay_enabled?: boolean; + payment_visible_method_wxpay_enabled?: boolean; + openai_advanced_scheduler_enabled?: boolean; // Balance & quota notification - balance_low_notify_enabled?: boolean - balance_low_notify_threshold?: number - balance_low_notify_recharge_url?: string - account_quota_notify_enabled?: boolean - account_quota_notify_emails?: NotifyEmailEntry[] + balance_low_notify_enabled?: boolean; + balance_low_notify_threshold?: number; + balance_low_notify_recharge_url?: string; + account_quota_notify_enabled?: boolean; + account_quota_notify_emails?: NotifyEmailEntry[]; } /** @@ -453,8 +554,8 @@ export interface UpdateSettingsRequest { * @returns System settings */ export async function getSettings(): Promise { - const { data } = await apiClient.get('/admin/settings') - return data + const { data } = await apiClient.get("/admin/settings"); + return data; } /** @@ -462,20 +563,25 @@ export async function getSettings(): Promise { * @param settings - Partial settings to update * @returns Updated settings */ -export async function updateSettings(settings: UpdateSettingsRequest): Promise { - const { data } = await apiClient.put('/admin/settings', settings) - return data +export async function updateSettings( + settings: UpdateSettingsRequest, +): Promise { + const { data } = await apiClient.put( + "/admin/settings", + settings, + ); + return data; } /** * Test SMTP connection request */ export interface TestSmtpRequest { - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_use_tls: boolean + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_use_tls: boolean; } /** @@ -483,23 +589,28 @@ export interface TestSmtpRequest { * @param config - SMTP configuration to test * @returns Test result message */ -export async function testSmtpConnection(config: TestSmtpRequest): Promise<{ message: string }> { - const { data } = await apiClient.post<{ message: string }>('/admin/settings/test-smtp', config) - return data +export async function testSmtpConnection( + config: TestSmtpRequest, +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>( + "/admin/settings/test-smtp", + config, + ); + return data; } /** * Send test email request */ export interface SendTestEmailRequest { - email: string - smtp_host: string - smtp_port: number - smtp_username: string - smtp_password: string - smtp_from_email: string - smtp_from_name: string - smtp_use_tls: boolean + email: string; + smtp_host: string; + smtp_port: number; + smtp_username: string; + smtp_password: string; + smtp_from_email: string; + smtp_from_name: string; + smtp_use_tls: boolean; } /** @@ -507,20 +618,22 @@ export interface SendTestEmailRequest { * @param request - Email address and SMTP config * @returns Test result message */ -export async function sendTestEmail(request: SendTestEmailRequest): Promise<{ message: string }> { +export async function sendTestEmail( + request: SendTestEmailRequest, +): Promise<{ message: string }> { const { data } = await apiClient.post<{ message: string }>( - '/admin/settings/send-test-email', - request - ) - return data + "/admin/settings/send-test-email", + request, + ); + return data; } /** * Admin API Key status response */ export interface AdminApiKeyStatus { - exists: boolean - masked_key: string + exists: boolean; + masked_key: string; } /** @@ -528,8 +641,10 @@ export interface AdminApiKeyStatus { * @returns Status indicating if key exists and masked version */ export async function getAdminApiKey(): Promise { - const { data } = await apiClient.get('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.get( + "/admin/settings/admin-api-key", + ); + return data; } /** @@ -537,8 +652,10 @@ export async function getAdminApiKey(): Promise { * @returns The new full API key (only shown once) */ export async function regenerateAdminApiKey(): Promise<{ key: string }> { - const { data } = await apiClient.post<{ key: string }>('/admin/settings/admin-api-key/regenerate') - return data + const { data } = await apiClient.post<{ key: string }>( + "/admin/settings/admin-api-key/regenerate", + ); + return data; } /** @@ -546,8 +663,10 @@ export async function regenerateAdminApiKey(): Promise<{ key: string }> { * @returns Success message */ export async function deleteAdminApiKey(): Promise<{ message: string }> { - const { data } = await apiClient.delete<{ message: string }>('/admin/settings/admin-api-key') - return data + const { data } = await apiClient.delete<{ message: string }>( + "/admin/settings/admin-api-key", + ); + return data; } // ==================== Overload Cooldown Settings ==================== @@ -556,23 +675,25 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { * Overload cooldown settings interface (529 handling) */ export interface OverloadCooldownSettings { - enabled: boolean - cooldown_minutes: number + enabled: boolean; + cooldown_minutes: number; } export async function getOverloadCooldownSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/overload-cooldown') - return data + const { data } = await apiClient.get( + "/admin/settings/overload-cooldown", + ); + return data; } export async function updateOverloadCooldownSettings( - settings: OverloadCooldownSettings + settings: OverloadCooldownSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/overload-cooldown', - settings - ) - return data + "/admin/settings/overload-cooldown", + settings, + ); + return data; } // ==================== Stream Timeout Settings ==================== @@ -581,11 +702,11 @@ export async function updateOverloadCooldownSettings( * Stream timeout settings interface */ export interface StreamTimeoutSettings { - enabled: boolean - action: 'temp_unsched' | 'error' | 'none' - temp_unsched_minutes: number - threshold_count: number - threshold_window_minutes: number + enabled: boolean; + action: "temp_unsched" | "error" | "none"; + temp_unsched_minutes: number; + threshold_count: number; + threshold_window_minutes: number; } /** @@ -593,8 +714,10 @@ export interface StreamTimeoutSettings { * @returns Stream timeout settings */ export async function getStreamTimeoutSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/stream-timeout') - return data + const { data } = await apiClient.get( + "/admin/settings/stream-timeout", + ); + return data; } /** @@ -603,13 +726,13 @@ export async function getStreamTimeoutSettings(): Promise * @returns Updated settings */ export async function updateStreamTimeoutSettings( - settings: StreamTimeoutSettings + settings: StreamTimeoutSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/stream-timeout', - settings - ) - return data + "/admin/settings/stream-timeout", + settings, + ); + return data; } // ==================== Rectifier Settings ==================== @@ -618,11 +741,11 @@ export async function updateStreamTimeoutSettings( * Rectifier settings interface */ export interface RectifierSettings { - enabled: boolean - thinking_signature_enabled: boolean - thinking_budget_enabled: boolean - apikey_signature_enabled: boolean - apikey_signature_patterns: string[] + enabled: boolean; + thinking_signature_enabled: boolean; + thinking_budget_enabled: boolean; + apikey_signature_enabled: boolean; + apikey_signature_patterns: string[]; } /** @@ -630,8 +753,10 @@ export interface RectifierSettings { * @returns Rectifier settings */ export async function getRectifierSettings(): Promise { - const { data } = await apiClient.get('/admin/settings/rectifier') - return data + const { data } = await apiClient.get( + "/admin/settings/rectifier", + ); + return data; } /** @@ -640,13 +765,13 @@ export async function getRectifierSettings(): Promise { * @returns Updated settings */ export async function updateRectifierSettings( - settings: RectifierSettings + settings: RectifierSettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/rectifier', - settings - ) - return data + "/admin/settings/rectifier", + settings, + ); + return data; } // ==================== Beta Policy Settings ==================== @@ -655,20 +780,20 @@ export async function updateRectifierSettings( * Beta policy rule interface */ export interface BetaPolicyRule { - beta_token: string - action: 'pass' | 'filter' | 'block' - scope: 'all' | 'oauth' | 'apikey' | 'bedrock' - error_message?: string - model_whitelist?: string[] - fallback_action?: 'pass' | 'filter' | 'block' - fallback_error_message?: string + beta_token: string; + action: "pass" | "filter" | "block"; + scope: "all" | "oauth" | "apikey" | "bedrock"; + error_message?: string; + model_whitelist?: string[]; + fallback_action?: "pass" | "filter" | "block"; + fallback_error_message?: string; } /** * Beta policy settings interface */ export interface BetaPolicySettings { - rules: BetaPolicyRule[] + rules: BetaPolicyRule[]; } /** @@ -676,8 +801,10 @@ export interface BetaPolicySettings { * @returns Beta policy settings */ export async function getBetaPolicySettings(): Promise { - const { data } = await apiClient.get('/admin/settings/beta-policy') - return data + const { data } = await apiClient.get( + "/admin/settings/beta-policy", + ); + return data; } /** @@ -686,70 +813,73 @@ export async function getBetaPolicySettings(): Promise { * @returns Updated settings */ export async function updateBetaPolicySettings( - settings: BetaPolicySettings + settings: BetaPolicySettings, ): Promise { const { data } = await apiClient.put( - '/admin/settings/beta-policy', - settings - ) - return data + "/admin/settings/beta-policy", + settings, + ); + return data; } // --- Web Search Emulation Config --- export interface WebSearchProviderConfig { - type: 'brave' | 'tavily' - api_key: string - api_key_configured: boolean - quota_limit: number | null - subscribed_at: number | null - quota_used?: number - proxy_id: number | null - expires_at: number | null + type: "brave" | "tavily"; + api_key: string; + api_key_configured: boolean; + quota_limit: number | null; + subscribed_at: number | null; + quota_used?: number; + proxy_id: number | null; + expires_at: number | null; } export interface WebSearchEmulationConfig { - enabled: boolean - providers: WebSearchProviderConfig[] + enabled: boolean; + providers: WebSearchProviderConfig[]; } export interface WebSearchTestResult { - provider: string - results: { url: string; title: string; snippet: string; page_age?: string }[] - query: string + provider: string; + results: { url: string; title: string; snippet: string; page_age?: string }[]; + query: string; } export async function getWebSearchEmulationConfig(): Promise { const { data } = await apiClient.get( - '/admin/settings/web-search-emulation' - ) - return data + "/admin/settings/web-search-emulation", + ); + return data; } export async function updateWebSearchEmulationConfig( - config: WebSearchEmulationConfig + config: WebSearchEmulationConfig, ): Promise { const { data } = await apiClient.put( - '/admin/settings/web-search-emulation', - config - ) - return data + "/admin/settings/web-search-emulation", + config, + ); + return data; } export async function testWebSearchEmulation( - query: string + query: string, ): Promise { const { data } = await apiClient.post( - '/admin/settings/web-search-emulation/test', - { query } - ) - return data + "/admin/settings/web-search-emulation/test", + { query }, + ); + return data; } -export async function resetWebSearchUsage( - payload: { provider_type: string } -): Promise { - await apiClient.post('/admin/settings/web-search-emulation/reset-usage', payload) +export async function resetWebSearchUsage(payload: { + provider_type: string; +}): Promise { + await apiClient.post( + "/admin/settings/web-search-emulation/reset-usage", + payload, + ); } export const settingsAPI = { @@ -771,7 +901,7 @@ export const settingsAPI = { getWebSearchEmulationConfig, updateWebSearchEmulationConfig, testWebSearchEmulation, - resetWebSearchUsage -} + resetWebSearchUsage, +}; -export default settingsAPI +export default settingsAPI; diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index e56afe5f..9612d9f9 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3,7 +3,9 @@
-
+
@@ -15,7 +17,10 @@ v-for="tab in settingsTabs" :key="tab.key" type="button" - :class="['settings-tab', activeTab === tab.key && 'settings-tab-active']" + :class="[ + 'settings-tab', + activeTab === tab.key && 'settings-tab-active', + ]" @click="activeTab = tab.key" > @@ -28,208 +33,62 @@
- -
-
-

- {{ t('admin.settings.adminApiKey.title') }} -

-

- {{ t('admin.settings.adminApiKey.description') }} -

-
-
- + +
-
- -

- {{ t('admin.settings.adminApiKey.securityWarning') }} -

-
-
- - -
-
- {{ t('common.loading') }} -
- - -
- - {{ t('admin.settings.adminApiKey.notConfigured') }} - - +

+ {{ t("admin.settings.adminApiKey.title") }} +

+

+ {{ t("admin.settings.adminApiKey.description") }} +

- - -
-
-
- - - {{ adminApiKeyMasked }} - -
-
- - -
-
- - +
+
-

- {{ t('admin.settings.adminApiKey.keyWarning') }} -

-
- - {{ newAdminApiKey }} - - -
-

- {{ t('admin.settings.adminApiKey.usage') }} -

-
-
-
-
-
- - -
- - -
-
-

- {{ t('admin.settings.overloadCooldown.title') }} -

-

- {{ t('admin.settings.overloadCooldown.description') }} -

-
-
-
-
- {{ t('common.loading') }} -
- - -
-
- - -
-
-

- {{ t('admin.settings.streamTimeout.title') }} -

-

- {{ t('admin.settings.streamTimeout.description') }} -

-
-
- -
-
- {{ t('common.loading') }} -
- - +
+ - -
-
-

- {{ t('admin.settings.rectifier.title') }} -

-

- {{ t('admin.settings.rectifier.description') }} -

-
-
- -
-
- {{ t('common.loading') }} + +
+ +
+
+

+ {{ t("admin.settings.overloadCooldown.title") }} +

+

+ {{ t("admin.settings.overloadCooldown.description") }} +

- -
+
-
+ - -
-
-

- {{ t('admin.settings.turnstile.title') }} -

-

- {{ t('admin.settings.turnstile.description') }} -

-
-
- -
-
- -

- {{ t('admin.settings.turnstile.enableTurnstileHint') }} -

-
- -
- - + +
+ +
-
+

+ {{ t("admin.settings.registration.title") }} +

+

+ {{ t("admin.settings.registration.description") }} +

+
+
+ +
- - -

- {{ t('admin.settings.turnstile.siteKeyHint') }} - {{ t('admin.settings.turnstile.cloudflareDashboard') }} + +

+ {{ + t("admin.settings.registration.enableRegistrationHint") + }}

+ +
+ + +
- - -

- {{ - form.turnstile_secret_key_configured - ? t('admin.settings.turnstile.secretKeyConfiguredHint') - : t('admin.settings.turnstile.secretKeyHint') - }} + +

+ {{ t("admin.settings.registration.emailVerificationHint") }}

+
-
-
-
- -
-
-

- {{ t('admin.settings.linuxdo.title') }} -

-

- {{ t('admin.settings.linuxdo.description') }} -

-
-
-
-
+ +
-

- {{ t('admin.settings.linuxdo.enableHint') }} +

+ {{ + t("admin.settings.registration.emailSuffixWhitelistHint") + }} +

+
+
+ + @ + {{ suffix }} + + + +
+ @ + +
+
+
+

+ {{ + t( + "admin.settings.registration.emailSuffixWhitelistInputHint", + ) + }}

- -
-
-
+ +
- - -

- {{ t('admin.settings.linuxdo.clientIdHint') }} + +

+ {{ t("admin.settings.registration.promoCodeHint") }}

+ +
+ +
- - -

- {{ - form.linuxdo_connect_client_secret_configured - ? t('admin.settings.linuxdo.clientSecretConfiguredHint') - : t('admin.settings.linuxdo.clientSecretHint') - }} + +

+ {{ t("admin.settings.registration.invitationCodeHint") }}

- -
- - -
- - - {{ linuxdoRedirectUrlSuggestion }} - -
-

- {{ t('admin.settings.linuxdo.redirectUrlHint') }} + +

+ +
+
+ +

+ {{ t("admin.settings.registration.passwordResetHint") }}

+
-
-
-
- - -
-
-

- {{ t('admin.settings.oidc.title') }} -

-

- {{ t('admin.settings.oidc.description') }} -

-
-
-
-
- -

- {{ t('admin.settings.oidc.enableHint') }} + +

+ + +

+ {{ t("admin.settings.registration.frontendUrlHint") }}

- -
- -
-
-
- - -
- -
- - -
+ +
- - -

- {{ - form.oidc_connect_client_secret_configured - ? t('admin.settings.oidc.clientSecretConfiguredHint') - : t('admin.settings.oidc.clientSecretHint') - }} + +

+ {{ t("admin.settings.registration.totpHint") }} +

+ +

+ {{ t("admin.settings.registration.totpKeyNotConfigured") }}

+
+
+
-
-
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
- -
- - -
-
- -
+ +
+
+

+ {{ t("admin.settings.turnstile.title") }} +

+

+ {{ t("admin.settings.turnstile.description") }} +

+
+
+ +
- - -

- {{ t('admin.settings.oidc.scopesHint') }} + +

+ {{ t("admin.settings.turnstile.enableTurnstileHint") }}

+ +
-
- - -
- - + +

+ {{ t("admin.settings.turnstile.siteKeyHint") }} + {{ + t("admin.settings.turnstile.cloudflareDashboard") + }} +

+
+
+ + {{ t("admin.settings.turnstile.secretKey") }} + + +

+ {{ + form.turnstile_secret_key_configured + ? t( + "admin.settings.turnstile.secretKeyConfiguredHint", + ) + : t("admin.settings.turnstile.secretKeyHint") + }} +

-

- {{ t('admin.settings.oidc.redirectUrlHint') }} -

-
- -
- - -

- {{ t('admin.settings.oidc.frontendRedirectUrlHint') }} -

+
+
-
-
- - -
- -
- - -
- + +
+
+

+ {{ t("admin.settings.linuxdo.title") }} +

+

+ {{ t("admin.settings.linuxdo.description") }} +

+
+
+
- - + +

+ {{ t("admin.settings.linuxdo.enableHint") }} +

+
-
-
+
+
-
- -
-
-
- -
-
-
-
+
+
-
-
- - -
- + +
+
+

+ {{ localText("微信登录", "WeChat Connect") }} +

+

+ {{ + localText( + "用于微信开放平台或公众号/小程序的第三方登录配置。", + "Third-party login configuration for WeChat Open Platform or Official Account / Mini Program.", + ) + }} +

+
+
+
- - -
- -
- - -
-
-
-
-
-
- - -
- -
-
-

- {{ t('admin.settings.defaults.title') }} -

-

- {{ t('admin.settings.defaults.description') }} -

-
-
-
-
- - -

- {{ t('admin.settings.defaults.defaultBalanceHint') }} -

-
-
- - -

- {{ t('admin.settings.defaults.defaultConcurrencyHint') }} -

-
-
- -
-
-
- +

- {{ t('admin.settings.defaults.defaultSubscriptionsHint') }} + {{ + localText( + "开启后可使用微信第三方登录回调与授权配置。", + "Enable this to configure WeChat OAuth callbacks and authorization.", + ) + }}

- +
- {{ t('admin.settings.defaults.defaultSubscriptionsEmpty') }} -
- -
-
+
- - -
-
-
-
- -
-
-
-
-
-
- -
-
-

- {{ localText('认证来源默认值', 'Auth Source Defaults') }} -

-

- {{ - localText( - '按注册来源配置新用户默认余额、并发、订阅与授权策略。', - 'Configure per-source default balance, concurrency, subscriptions, and grant rules.' - ) - }} -

-
-
-
-
- -

- {{ - localText( - '启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。', - 'When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.' - ) - }} -

-
- -
- -
-
-
-
{{ authSource.title }}
-

- {{ authSource.description }} -

-
-
-
+
-
-
- -
-
-
- -

+ +

+ {{ + localText( + "open 对应微信开放平台,mp 对应公众号/小程序授权。", + "open maps to WeChat Open Platform, mp maps to Official Account / Mini Program authorization.", + ) + }} +

+
-
-
- -

- {{ - localText( - '来源首次绑定到现有账号时发放默认权益。', - 'Grant default entitlements when the source is first bound to an existing user.' - ) - }} -

-
- +
+
+ + +

+ {{ + localText( + "留空时会按模式自动回填默认值。", + "Leave empty to use the default scope for the selected mode.", + ) + }} +

-
-
-
-
- -

+

+ + +
+ + + {{ wechatRedirectUrlSuggestion }} +
-
+
-
+ + +

{{ localText( - '当前来源未配置专属默认订阅。', - 'No source-specific default subscriptions configured.' + "通常用于前端路由回调地址,需与后端配置保持一致。", + "Usually the frontend route callback path; keep it aligned with the backend.", ) }} -

- -
-
-
- - -
-
- - -
-
- -
-
-
+

-
-
- -
- -
-
-

- {{ t('admin.settings.claudeCode.title') }} -

-

- {{ t('admin.settings.claudeCode.description') }} -

-
-
-
- - -

- {{ t('admin.settings.claudeCode.minVersionHint') }} -

-
-
- - -

- {{ t('admin.settings.claudeCode.maxVersionHint') }} + +

+
+

+ {{ t("admin.settings.oidc.title") }} +

+

+ {{ t("admin.settings.oidc.description") }}

-
-
- - -
-
-

- {{ t('admin.settings.scheduling.title') }} -

-

- {{ t('admin.settings.scheduling.description') }} -

-
-
-
+
- -

- {{ t('admin.settings.scheduling.allowUngroupedKeyHint') }} + +

+ {{ t("admin.settings.oidc.enableHint") }}

- +
-
-
- -

- {{ - localText( - '默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。', - 'Disabled by default. When enabled, this only changes the gateway\'s experimental account-selection policy for OpenAI traffic; it does not indicate an upstream OpenAI capability.' - ) - }} -

-
- -
-
-
-
+
+
+
+ + +
- -
-
-

- {{ t('admin.settings.gatewayForwarding.title') }} -

-

- {{ t('admin.settings.gatewayForwarding.description') }} -

-
-
- -
-
- -

- {{ t('admin.settings.gatewayForwarding.fingerprintUnificationHint') }} -

-
- -
+
+ + +
- -
-
- -

- {{ t('admin.settings.gatewayForwarding.metadataPassthroughHint') }} -

-
- -
+
+ + +

+ {{ + form.oidc_connect_client_secret_configured + ? t("admin.settings.oidc.clientSecretConfiguredHint") + : t("admin.settings.oidc.clientSecretHint") + }} +

+
+
- -
-
- -

- {{ t('admin.settings.gatewayForwarding.cchSigningHint') }} -

-
- -
-
-
- -
-
-

- {{ t('admin.settings.webSearchEmulation.title') }} -

-

- {{ t('admin.settings.webSearchEmulation.description') }} -

-
-
- -
-
- -

- {{ t('admin.settings.webSearchEmulation.enabledHint') }} -

-
- -
+
+
+ + +
- -
-
- - -
+
+ + +
-
- {{ t('admin.settings.webSearchEmulation.noProviders') }} -
+
+ + +
-
- -
-
- + - +
+ +
+ + +
+ +
+ + - - - {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} - - - {{ t('admin.settings.webSearchEmulation.apiKeyConfigured') }} -
-
- -
- +
- -
- -
- - -
+ + +

+ {{ t("admin.settings.oidc.scopesHint") }} +

+
+ +
+ + +
+ + + {{ oidcRedirectUrlSuggestion }} +
+

+ {{ t("admin.settings.oidc.redirectUrlHint") }} +

+
+ +
+ + +

+ {{ t("admin.settings.oidc.frontendRedirectUrlHint") }} +

+
+
+ +
+
+ + +
+ +
+ + +
+ +
+ +
+
- -
+
+
- - -

{{ t('admin.settings.webSearchEmulation.quotaLimitHint') }}

+
+ +
+ +
- - -

{{ t('admin.settings.webSearchEmulation.subscribedAtHint') }}

+
+
- -
- {{ t('admin.settings.webSearchEmulation.quotaUsage') }}: -
-
+
+
+
-
- {{ provider.quota_used ?? 0 }} / {{ provider.quota_limit != null && provider.quota_limit > 0 ? provider.quota_limit : '∞' }} -
+
+ +
+
+ +
- -
-
- - -
- + {{ t("admin.settings.oidc.userinfoIdPath") }} + +
-
-
-
-
-
- -
-
-

- {{ t('admin.settings.webSearchEmulation.testResultTitle') }} -

-
- - -
- -
-

- {{ t('admin.settings.webSearchEmulation.testResultProvider') }}: {{ wsTestResult.provider }} -

-
- {{ t('admin.settings.webSearchEmulation.testNoResults') }} -
-
- {{ r.title }} -

{{ r.snippet }}

+
+ + +
+
-
- -
+ -
- - -
- -
-
-

- {{ t('admin.settings.site.title') }} -

-

- {{ t('admin.settings.site.description') }} -

-
-
- + +
+ +
-
-

- {{ t('admin.settings.site.backendMode') }} -

-

- {{ t('admin.settings.site.backendModeDescription') }} -

-
- -
- -
-
- - -

- {{ t('admin.settings.site.siteNameHint') }} -

-
-
- - -

- {{ t('admin.settings.site.siteSubtitleHint') }} -

-
-
- - -
- - -

- {{ t('admin.settings.site.apiBaseUrlHint') }} +

+ {{ t("admin.settings.defaults.title") }} +

+

+ {{ t("admin.settings.defaults.description") }}

- - -
-

- {{ t('admin.settings.site.tablePreferencesTitle') }} -

-

- {{ t('admin.settings.site.tablePreferencesDescription') }} -

-
+
+
-
-
-
- - -
- -

- {{ t('admin.settings.site.customEndpoints.description') }} -

-
-
-
- - {{ t('admin.settings.site.customEndpoints.itemLabel', { n: index + 1 }) }} - +
+
+
+ +

+ {{ + t("admin.settings.defaults.defaultSubscriptionsHint") + }} +

+
+ +
+ +
+ {{ t("admin.settings.defaults.defaultSubscriptionsEmpty") }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+
+
+ +
+
+

+ {{ localText("认证来源默认值", "Auth Source Defaults") }} +

+

+ {{ + localText( + "按注册来源配置新用户默认余额、并发、订阅与授权策略。", + "Configure per-source default balance, concurrency, subscriptions, and grant rules.", + ) + }} +

+
+
+
+
+ +

+ {{ + localText( + "启用后,Linux DO、OIDC、微信注册缺少邮箱时必须先补充邮箱地址。", + "When enabled, Linux DO, OIDC, and WeChat signups must provide an email before account creation.", + ) + }} +

+
+ +
+ +
+
+
+
+ {{ authSource.title }} +
+

+ {{ authSource.description }} +

+
+ +
+
+ + +
+
+ + +
+
+ +
+
+
+ +

+ {{ + localText( + "来源首次注册成功后立即发放默认权益。", + "Grant default entitlements immediately after signup.", + ) + }} +

+
+ +
+ +
+
+ +

+ {{ + localText( + "来源首次绑定到现有账号时发放默认权益。", + "Grant default entitlements when the source is first bound to an existing user.", + ) + }} +

+
+ +
+
+ +
+
+
+ +

+ {{ + localText( + "仅对当前认证来源生效,未配置时不追加来源专属订阅。", + "Applies only to this auth source. Leave empty to skip source-specific subscriptions.", + ) + }} +

+
+ +
+ +
+ {{ + localText( + "当前来源未配置专属默认订阅。", + "No source-specific default subscriptions configured.", + ) + }} +
+ +
+
+
+ + +
+
+ + +
+
+ +
+
+
+
+
+
+
+
+
+ + + +
+ +
+
+

+ {{ t("admin.settings.claudeCode.title") }} +

+

+ {{ t("admin.settings.claudeCode.description") }} +

+
+
+
+ + +

+ {{ t("admin.settings.claudeCode.minVersionHint") }} +

+
+
+ + +

+ {{ t("admin.settings.claudeCode.maxVersionHint") }} +

+
+
+
+ + +
+
+

+ {{ t("admin.settings.scheduling.title") }} +

+

+ {{ t("admin.settings.scheduling.description") }} +

+
+
+
+
+
+ +

+ {{ t("admin.settings.scheduling.allowUngroupedKeyHint") }} +

+
+ +
+ +
+
+ +

+ {{ + localText( + "默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。", + "Disabled by default. When enabled, this only changes the gateway's experimental account-selection policy for OpenAI traffic; it does not indicate an upstream OpenAI capability.", + ) + }} +

+
+ +
+
+
+
+ + +
+
+

+ {{ t("admin.settings.gatewayForwarding.title") }} +

+

+ {{ t("admin.settings.gatewayForwarding.description") }} +

+
+
+ +
+
+ +

+ {{ + t( + "admin.settings.gatewayForwarding.fingerprintUnificationHint", + ) + }} +

+
+ +
+ + +
+
+ +

+ {{ + t( + "admin.settings.gatewayForwarding.metadataPassthroughHint", + ) + }} +

+
+ +
+ + +
+
+ +

+ {{ t("admin.settings.gatewayForwarding.cchSigningHint") }} +

+
+ +
+
+
+ +
+
+

+ {{ t("admin.settings.webSearchEmulation.title") }} +

+

+ {{ t("admin.settings.webSearchEmulation.description") }} +

+
+
+ +
+
+ +

+ {{ t("admin.settings.webSearchEmulation.enabledHint") }} +

+
+ +
+ + +
+
+ + +
+ +
+ {{ t("admin.settings.webSearchEmulation.noProviders") }} +
+ +
+ +
+
+ + + + +
+ + +
+
+
+ + +
+
+ + +

+ {{ + t( + "admin.settings.webSearchEmulation.quotaLimitHint", + ) + }} +

+
+
+ + +

+ {{ + t( + "admin.settings.webSearchEmulation.subscribedAtHint", + ) + }} +

+
+
+ + +
+ {{ + t("admin.settings.webSearchEmulation.quotaUsage") + }}: +
+
+
+
+ {{ provider.quota_used ?? 0 }} / + {{ + provider.quota_limit != null && + provider.quota_limit > 0 + ? provider.quota_limit + : "∞" + }} + +
+ + +
+
+ + +
+ +
+
+
+
+
+
+ + +
+
+

+ {{ t("admin.settings.webSearchEmulation.testResultTitle") }} +

+
+ + +
+ +
+

+ {{ + t("admin.settings.webSearchEmulation.testResultProvider") + }}: {{ wsTestResult.provider }} +

+
+ {{ t("admin.settings.webSearchEmulation.testNoResults") }} +
+
+ {{ r.title }} +

+ {{ r.snippet }} +

+
+
+
+ +
+
+
+
+ + + +
+ +
+
+

+ {{ t("admin.settings.site.title") }} +

+

+ {{ t("admin.settings.site.description") }} +

+
+
+ +
+
+

+ {{ t("admin.settings.site.backendMode") }} +

+

+ {{ t("admin.settings.site.backendModeDescription") }} +

+
+ +
+ +
+
+ + +

+ {{ t("admin.settings.site.siteNameHint") }} +

+
+
+ + +

+ {{ t("admin.settings.site.siteSubtitleHint") }} +

+
+
+ + +
+ + +

+ {{ t("admin.settings.site.apiBaseUrlHint") }} +

+
+ + +
+

+ {{ t("admin.settings.site.tablePreferencesTitle") }} +

+

+ {{ t("admin.settings.site.tablePreferencesDescription") }} +

+
+
+ + +

+ {{ t("admin.settings.site.tableDefaultPageSizeHint") }} +

+
+
+ + +

+ {{ t("admin.settings.site.tablePageSizeOptionsHint") }} +

+
+
+
+ + +
+ +

+ {{ t("admin.settings.site.customEndpoints.description") }} +

+ +
+
+
+ + {{ + t("admin.settings.site.customEndpoints.itemLabel", { + n: index + 1, + }) + }} + + +
+
+
+ + +
+
+ + +
+
+ + +
+
+
+
+ + +
+ + +
+ + +

+ {{ t("admin.settings.site.contactInfoHint") }} +

+
+ + +
+ + +

+ {{ t("admin.settings.site.docUrlHint") }} +

+
+ + +
+ + +
+ + +
+ + +

+ {{ t("admin.settings.site.homeContentHint") }} +

+ +

+ {{ t("admin.settings.site.homeContentIframeWarning") }} +

+
+ + +
+
+ +

+ {{ t("admin.settings.site.hideCcsImportButtonHint") }} +

+
+ +
+
+
+ + +
+
+

+ {{ t("admin.settings.customMenu.title") }} +

+

+ {{ t("admin.settings.customMenu.description") }} +

+
+
+ +
+
+ + {{ + t("admin.settings.customMenu.itemLabel", { n: index + 1 }) + }} + +
+ + + + +
-
-
- +
+ +
+ +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+ + +
+
+
+ + + +
+
+
+ + + + +
+ +
+
+

+ {{ t("admin.settings.payment.title") }} +

+

+ {{ t("admin.settings.payment.description") }} + + + + + {{ t("admin.settings.payment.configGuide") }} + +

+
+
+ +
+
+ +

+ {{ t("admin.settings.payment.enabledHint") }} +

+
+ +
+
+
- -
- - -
+ + +
- -
- - -

- {{ t('admin.settings.site.homeContentHint') }} -

- -

- {{ t('admin.settings.site.homeContentIframeWarning') }} -

+
+ +
+
+
+ +
+

+ {{ t("admin.settings.emailTabDisabledTitle") }} +

+

+ {{ t("admin.settings.emailTabDisabledHint") }} +

+
+
+
- + +
- -

- {{ t('admin.settings.site.hideCcsImportButtonHint') }} +

+ {{ t("admin.settings.smtp.title") }} +

+

+ {{ t("admin.settings.smtp.description") }}

- +
-
-
- - - -
-
-

- {{ t('admin.settings.customMenu.title') }} -

-

- {{ t('admin.settings.customMenu.description') }} -

-
-
- -
-
- - {{ t('admin.settings.customMenu.itemLabel', { n: index + 1 }) }} - -
- - - - - - -
-
- -
- +
+
-
- -
- - -
- - -
-
- - -
-
-
- - - -
-
- -
- - - -
- - -
-
-

{{ t('admin.settings.payment.title') }}

-

- {{ t('admin.settings.payment.description') }} - - - {{ t('admin.settings.payment.configGuide') }} - -

-
-
- -
-
- -

{{ t('admin.settings.payment.enabledHint') }}

-
- -
- +
-
- - - - -
-
- -
-
-
- -
-

- {{ t('admin.settings.emailTabDisabledTitle') }} -

-

- {{ t('admin.settings.emailTabDisabledHint') }} -

+ +
+
+

+ {{ t("admin.settings.testEmail.title") }} +

+

+ {{ t("admin.settings.testEmail.description") }} +

+
+
+
+
+ + +
+
-
- - -
-
-
-

- {{ t('admin.settings.smtp.title') }} -

+ +
+
+

+ {{ t("admin.settings.balanceNotify.title") }} +

- {{ t('admin.settings.smtp.description') }} + {{ t("admin.settings.balanceNotify.description") }}

- -
-
-
-
- - -
-
- - -
-
- - +
+
+ +
-
- - -

- {{ - form.smtp_password_configured - ? t('admin.settings.smtp.passwordConfiguredHint') - : t('admin.settings.smtp.passwordHint') - }} +

+ +
+ $ + +
+

+ {{ t("admin.settings.balanceNotify.thresholdHint") }}

- - -
-
- + +

+ {{ t("admin.settings.balanceNotify.rechargeUrlHint") }} +

+
- + +
-
- -

- {{ t('admin.settings.smtp.useTlsHint') }} -

-
- +

+ {{ t("admin.settings.quotaNotify.title") }} +

+

+ {{ t("admin.settings.quotaNotify.description") }} +

- -
-
- - -
-
-

- {{ t('admin.settings.testEmail.title') }} -

-

- {{ t('admin.settings.testEmail.description') }} -

-
-
-
-
- - -
- -
-
-
- -
-
-

- {{ t('admin.settings.balanceNotify.title') }} -

-

- {{ t('admin.settings.balanceNotify.description') }} -

-
-
-
- - -
-
- -
- $ - +
-

{{ t('admin.settings.balanceNotify.thresholdHint') }}

-
-
- - -

{{ t('admin.settings.balanceNotify.rechargeUrlHint') }}

-
-
-
- - -
-
-

- {{ t('admin.settings.quotaNotify.title') }} -

-

- {{ t('admin.settings.quotaNotify.description') }} -

-
-
-
- - -
-
- -
-
- - - +
+
- +

+ {{ t("admin.settings.quotaNotify.emailsHint") }} +

-

{{ t('admin.settings.quotaNotify.emailsHint') }}

-
+
@@ -3070,8 +4644,17 @@
-
@@ -3104,22 +4691,32 @@ @close="showProviderDialog = false" @save="handleSaveProvider" /> - +
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue index 67409a7c..2e3db61b 100644 --- a/frontend/src/components/admin/account/AccountTestModal.vue +++ b/frontend/src/components/admin/account/AccountTestModal.vue @@ -55,12 +55,12 @@ />
-
+